[whisper ファインチューニング]RuntimeError:Given groups=1, weight of size [1280, 128, 3], expected input[1, ・

実現したいこと

・whisper を自前のデータセットでファインチューニングしたい

前提

発生している問題・エラーメッセージ

--------------------------------------------------------------------------- RuntimeError Traceback (most recent call last) <ipython-input-31-864a6af94846> in <cell line: 9>() 9 for i, f in enumerate(fileList): 10 mel = log_mel_spectrogram(f) ---> 11 probs = model.detect_language(mel) 12 print(probs) 13 # 30秒データに整形 7 frames /usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py in decorate_context(*args, **kwargs) 113 def decorate_context(*args, **kwargs): 114 with ctx_factory(): --> 115 return func(*args, **kwargs) 116 117 return decorate_context /usr/local/lib/python3.10/dist-packages/whisper/decoding.py in detect_language(model, mel, tokenizer) 50 # skip encoder forward pass if already-encoded audio features were given 51 if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state): ---> 52 mel = model.encoder(mel) 53 54 # forward pass using a single token, startoftranscript /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs) 1499 if recording_scopes: 1500 # type ignore was added because at this point one knows that -> 1501 # torch.jit._trace._trace_module_map is not Optional and has type Dict[Any, Any] 1502 name = torch.jit._trace._trace_module_map[self] if self in torch.jit._trace._trace_module_map else None # type: ignore[index, operator] # noqa: B950 1503 if name: /usr/local/lib/python3.10/dist-packages/whisper/model.py in forward(self, x) 160 the mel spectrogram of the audio 161 """ --> 162 x = F.gelu(self.conv1(x)) 163 x = F.gelu(self.conv2(x)) 164 x = x.permute(0, 2, 1) /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs) 1499 if recording_scopes: 1500 # type ignore was added because at this point one knows that -> 1501 # torch.jit._trace._trace_module_map is not Optional and has type Dict[Any, Any] 1502 name = torch.jit._trace._trace_module_map[self] if self in torch.jit._trace._trace_module_map else None # type: ignore[index, operator] # noqa: B950 1503 if name: /usr/local/lib/python3.10/dist-packages/torch/nn/modules/conv.py in forward(self, input) 311 312 --> 313 class Conv2d(_ConvNd): 314 __doc__ = r"""Applies a 2D convolution over an input signal composed of several input 315 planes. /usr/local/lib/python3.10/dist-packages/whisper/model.py in _conv_forward(self, x, weight, bias) 46 self, x: Tensor, weight: Tensor, bias: Optional[Tensor] 47 ) -> Tensor: ---> 48 return super()._conv_forward( 49 x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype) 50 ) /usr/local/lib/python3.10/dist-packages/torch/nn/modules/conv.py in _conv_forward(self, input, weight, bias) 307 self.padding, self.dilation, self.groups) 308 --> 309 def forward(self, input: Tensor) -> Tensor: 310 return self._conv_forward(input, self.weight, self.bias) 311 RuntimeError: Given groups=1, weight of size [1280, 128, 3], expected input[1, 80, 3000] to have 128 channels, but got 80 channels instead

該当のソースコード

python

1from datasets import Dataset,Audio 2from whisper.audio import N_FRAMES, pad_or_trim, log_mel_spectrogram 3import whisper 4from whisper.tokenizer import get_tokenizer 5 6datasets = Dataset.from_dict({"audio": fileList}).cast_column("audio", Audio())7predict_data = []8for i, f in enumerate(fileList):9 mel = log_mel_spectrogram(f)10 # 30秒データに整形11 segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(torch.float16)12 # デコード13 result = model.decode(segment)14 # トークナイザ取得`15 tokenizer = get_tokenizer(multilingual=True, language="ja", task="transcribe")16 # トークナイザのデコード17 outputText = tokenizer.decode(result.tokens)18 predict_data.append({**datasets[i], "sentence":outputText})

試したこと

mel をprintでデバックし確認した。
しかし初学者には難しく良く分からなかった。

chatGPTに助けを求めたが、チンプンカンプンなプログラムばかり出力しこの問題においては対処できなかった。

#追記(2023年12月2日:14時05分)
チャンネル数の不一致によっておこるエラーで、4チャンネルにすることで解決できるという記事を発見した
PyTorchでよくあるエラーの対処方法(次元やチャンネル数)

segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(torch.float16)

この部分を4チャンネルにすればいいのではと考えたものの、解決策は見つからなかった。

補足情報(FW/ツールのバージョンなど)

自作データセットでWhisperをファインチューニングしたら、独自用語だらけのクラロワ実況でも使えるようになった:「データセット作成編」
↑このサイトにある通りのプログラムでやっています。
音声は、ITAコーパス(emoNormal)の録音音声を1つのwavファイルに結合したものを使用している。
Hugging FaceでOpenAIの音声認識”Whisper”をFine Tuningする方法が公開されました

コメントを投稿

0 コメント