[whisper ファインチューニング] RuntimeError: Given groups=1, weight of size [1280, 128, 3], expected input[1, 80, 3000] to have 128 channels, but got 80 channels instead


・whisper を自前のデータセットでファインチューニングしたい



--------------------------------------------------------------------------- RuntimeError Traceback (most recent call last) <ipython-input-31-864a6af94846> in <cell line: 9>() 9 for i, f in enumerate(fileList): 10 mel = log_mel_spectrogram(f) ---> 11 probs = model.detect_language(mel) 12 print(probs) 13 # 30秒データに整形 7 frames /usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py in decorate_context(*args, **kwargs) 113 def decorate_context(*args, **kwargs): 114 with ctx_factory(): --> 115 return func(*args, **kwargs) 116 117 return decorate_context /usr/local/lib/python3.10/dist-packages/whisper/decoding.py in detect_language(model, mel, tokenizer) 50 # skip encoder forward pass if already-encoded audio features were given 51 if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state): ---> 52 mel = model.encoder(mel) 53 54 # forward pass using a single token, startoftranscript /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs) 1499 if recording_scopes: 1500 # type ignore was added because at this point one knows that -> 1501 # torch.jit._trace._trace_module_map is not Optional and has type Dict[Any, Any] 1502 name = torch.jit._trace._trace_module_map[self] if self in torch.jit._trace._trace_module_map else None # type: ignore[index, operator] # noqa: B950 1503 if name: /usr/local/lib/python3.10/dist-packages/whisper/model.py in forward(self, x) 160 the mel spectrogram of the audio 161 """ --> 162 x = F.gelu(self.conv1(x)) 163 x = F.gelu(self.conv2(x)) 164 x = x.permute(0, 2, 1) /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs) 1499 if recording_scopes: 1500 # type ignore was added because at this point one knows that -> 1501 # torch.jit._trace._trace_module_map is not Optional and has type Dict[Any, Any] 1502 name = torch.jit._trace._trace_module_map[self] if self in torch.jit._trace._trace_module_map else None # type: ignore[index, 
operator] # noqa: B950 1503 if name: /usr/local/lib/python3.10/dist-packages/torch/nn/modules/conv.py in forward(self, input) 311 312 --> 313 class Conv2d(_ConvNd): 314 __doc__ = r"""Applies a 2D convolution over an input signal composed of several input 315 planes. /usr/local/lib/python3.10/dist-packages/whisper/model.py in _conv_forward(self, x, weight, bias) 46 self, x: Tensor, weight: Tensor, bias: Optional[Tensor] 47 ) -> Tensor: ---> 48 return super()._conv_forward( 49 x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype) 50 ) /usr/local/lib/python3.10/dist-packages/torch/nn/modules/conv.py in _conv_forward(self, input, weight, bias) 307 self.padding, self.dilation, self.groups) 308 --> 309 def forward(self, input: Tensor) -> Tensor: 310 return self._conv_forward(input, self.weight, self.bias) 311 RuntimeError: Given groups=1, weight of size [1280, 128, 3], expected input[1, 80, 3000] to have 128 channels, but got 80 channels instead



from datasets import Dataset, Audio
from whisper.audio import N_FRAMES, pad_or_trim, log_mel_spectrogram
import whisper
from whisper.tokenizer import get_tokenizer

# NOTE: `model`, `fileList` and `torch` are not defined in this snippet; they
# are expected to come from earlier notebook cells.
datasets = Dataset.from_dict({"audio": fileList}).cast_column("audio", Audio())

# The tokenizer does not depend on the audio file, so build it once instead of
# on every iteration. `num_languages` is taken from the model so that the
# special-token layout matches the loaded checkpoint.
tokenizer = get_tokenizer(
    multilingual=True,
    num_languages=model.num_languages,
    language="ja",
    task="transcribe",
)

predict_data = []
for i, f in enumerate(fileList):
    # The number of mel bins must match what the encoder's first conv layer
    # expects. log_mel_spectrogram defaults to 80 bins, which raises
    # "expected input[1, 80, 3000] to have 128 channels" on checkpoints whose
    # encoder takes 128 bins, so read the value from the model itself.
    mel = log_mel_spectrogram(f, n_mels=model.dims.n_mels)
    # Pad or trim to N_FRAMES, then match the model's device and use float16.
    segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(torch.float16)
    # Run decoding on the prepared segment.
    result = model.decode(segment)
    # Convert the predicted token ids back into text.
    outputText = tokenizer.decode(result.tokens)
    predict_data.append({**datasets[i], "sentence": outputText})


mel を print でデバッグして確認した。



# Pad or trim the mel spectrogram to N_FRAMES, move it to the model's device, then cast it to float16.
segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(torch.float16)



Hugging FaceでOpenAIの音声認識”Whisper”をFine Tuningする方法が公開されました


