I'm trying to run inference with a pretrained DiffWave model on the output of a SepFormer model (separating a 2-speaker mixture). Creating the mel spectrogram and calling predict like this:
from torch import clamp, log10
from torchaudio.transforms import MelSpectrogram
from diffwave.inference import predict

# det_estimation: SepFormer output, shape [32, 2, 32000] (batch, speakers, samples), defined elsewhere
# base_params: DiffWave hyperparameters for the checkpoint, defined elsewhere
mel_args = {
    'sample_rate': 8000,
    'win_length': 384,
    'hop_length': 192,
    'n_fft': 384,
    'f_min': 80.0,
    'f_max': 3000.0,
    'n_mels': 64,
    'power': 2.0,
    'normalized': False,
}
mel_transform = MelSpectrogram(**mel_args)
mel_spec = mel_transform(det_estimation)  # [32, 2, 64, 167]
# Convert to dB, then normalize to [0, 1]
mel_spec = 20 * log10(clamp(mel_spec, min=1e-5)) - 20
mel_spec = clamp((mel_spec + 100) / 100, 0.0, 1.0)
mel_trimmed = mel_spec[:, :, :, :-1]  # drop the last frame: [32, 2, 64, 166]
B, C, nmel, ntime = mel_trimmed.shape  # B: 32, C: 2, nmel: 64, ntime: 166
enlarged_spectrogram = mel_trimmed.view(B * C, nmel, ntime)  # merge batch and speakers: [64, 64, 166]
gen_estimation, sr = predict(enlarged_spectrogram, 'diffwave-weights-902319.pt', base_params, fast_sampling=True, device='cpu')
leads to the following error:
Traceback (most recent call last):
  File "path/test.py", line 29, in <module>
    gen_estimation, _ = diffwave_predict(enlarged_spectrogram, 'diffwave-weights-902319.pt', base_params, fast_sampling=True,
  File "path/diffwave/inference.py", line 81, in predict
    audio = c1 * (audio - c2 * model(audio, spectrogram, torch.tensor([T[n]], device=audio.device)).squeeze(1))
  File "/home/eitzo/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "path/diffwave/model.py", line 152, in forward
    diffusion_step = self.diffusion_embedding(diffusion_step)
  File "/home/eitzo/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "path/diffwave/model.py", line 50, in forward
    x = self._lerp_embedding(diffusion_step)
  File "path/diffwave/model.py", line 62, in _lerp_embedding
    return low + (high - low) * (t - low_idx)
RuntimeError: The size of tensor a (128) must match the size of tensor b (166) at non-singleton dimension 3
Process finished with exit code 1
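The sizes in the error can be traced from tensor shapes alone. Below is a shape-only sketch in plain Python (no torch needed), assuming that _lerp_embedding indexes a [max_steps, 128] embedding table with t, as the failing line low + (high - low) * (t - low_idx) and the 128 in the message suggest, and that broadcasting follows the usual NumPy/PyTorch trailing-dimension rules. The helper names stft_frames, broadcast_shapes and lerp_shape, and the EMBED_DIM constant, are hypothetical. In this trace, 166 is the trimmed frame count (1 + 32000 // 192 = 167, minus one), and the leading 64s are B * C and n_mels.

```python
# Shape-only trace of the failing call (hypothetical helpers, no torch).
# Assumption: the diffusion-step embedding table is [max_steps, EMBED_DIM].

EMBED_DIM = 128  # assumed embedding width, matching "tensor a (128)"


def stft_frames(num_samples, n_fft, hop_length, center=True):
    """Frames torch.stft yields (torchaudio's MelSpectrogram uses center=True by default)."""
    if center:
        # center=True pads n_fft // 2 samples on each side of the signal.
        num_samples += 2 * (n_fft // 2)
    return 1 + (num_samples - n_fft) // hop_length


def broadcast_shapes(a, b):
    """Broadcast two shapes, checking from the trailing dimension first."""
    ndim = max(len(a), len(b))
    a = (1,) * (ndim - len(a)) + tuple(a)  # left-pad with 1s to equal rank
    b = (1,) * (ndim - len(b)) + tuple(b)
    out = [0] * ndim
    for dim in range(ndim - 1, -1, -1):
        x, y = a[dim], b[dim]
        if x == y or y == 1:
            out[dim] = x
        elif x == 1:
            out[dim] = y
        else:
            raise ValueError(
                f"size of tensor a ({x}) must match size of tensor b ({y}) "
                f"at non-singleton dimension {dim}"
            )
    return tuple(out)


def lerp_shape(t_shape):
    """Result shape of (high - low) * (t - low_idx) for a step tensor t."""
    # Indexing the [max_steps, EMBED_DIM] table with t appends EMBED_DIM.
    diff_shape = tuple(t_shape) + (EMBED_DIM,)
    # (t - low_idx) keeps the shape of t.
    return broadcast_shapes(diff_shape, t_shape)


frames = stft_frames(32000, n_fft=384, hop_length=192)  # 167
spec_shape = (32 * 2, 64, frames - 1)  # (64, 64, 166) after trim + view

print(lerp_shape((1,)))  # 1-element step tensor, like torch.tensor([T[n]])
try:
    lerp_shape(spec_shape)  # the spectrogram used as the step tensor
except ValueError as err:
    print(err)  # same sizes as the traceback: 128 vs 166 at dimension 3
```

If this reading is right, the spectrogram is reaching the diffusion_step argument of the model's forward, while a 1-element step tensor would broadcast cleanly. One thing worth checking is whether the argument order in the model(audio, spectrogram, ...) call in inference.py matches the forward signature in model.py, for example if the two files come from different versions.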
Any ideas where I went wrong? Thanks in advance!