RuntimeError: 対象OSGAN

実現したいこと

学習済みモデルから偽画像を生成してみたい

前提

https://github.com/zju-vipa/OSGAN
上記のURLを実装し、学習させ保存した学習済みモデルを用いて、偽画像の生成を実行しようとした際にエラーが発生しました。
画像サイズは64*64

発生している問題・エラーメッセージ

python

1エラーメッセージ 2EN:RuntimeError: Error(s) in loading state_dict for DCGenerator:3 size mismatch for deconv1.weight: copying a param with shape torch.Size([128, 1024, 4, 4]) from checkpoint, the shape in current model is torch.Size([100, 1024, 4, 4]).4 size mismatch for deconv5.weight: copying a param with shape torch.Size([128, 3, 4, 4]) from checkpoint, the shape in current model is torch.Size([128, 1, 4, 4]).5 size mismatch for deconv5.bias: copying a param with shape torch.Size([3]) from checkpoint, the shape in current model is torch.Size([1]).6 7JP:RuntimeError: DCGenerator の state_dict のロード中にエラーが発生しました:8deconv1.weight のサイズが一致しません: torch.Size([128, 1024, 4, 4]) の形状を持つパラメータをチェックポイントからコピーしています。現在のモデルの形状は torch.Size([100, 1024, 4, 4]) です。 9deconv5.weight のサイズが一致しません: torch.Size([128, 3, 4, 4]) の形状を持つパラメータをチェックポイントからコピーすると、現在のモデルの形状は torch.Size([128, 1, 4, 4]) になります。 10deconv5.bias のサイズが一致しません: torch.Size([3]) の形状を持つパラメータをチェックポイントからコピーします。現在のモデルの形状は torch.Size([1]) です。

該当のソースコード

python

1class DCGenerator(nn.Module):2 def __init__(self, zdim=100, num_channel=1, d=128):3 super(DCGenerator, self).__init__()4 self.zdim = zdim 5 self.deconv1 = nn.ConvTranspose2d(zdim, d * 8, 4, 1, 0)6 self.deconv1_bn = nn.BatchNorm2d(d * 8)7 self.deconv2 = nn.ConvTranspose2d(d * 8, d * 4, 4, 2, 1)8 self.deconv2_bn = nn.BatchNorm2d(d * 4)9 self.deconv3 = nn.ConvTranspose2d(d * 4, d * 2, 4, 2, 1)10 self.deconv3_bn = nn.BatchNorm2d(d * 2)11 self.deconv4 = nn.ConvTranspose2d(d * 2, d, 4, 2, 1)12 self.deconv4_bn = nn.BatchNorm2d(d)13 self.deconv5 = nn.ConvTranspose2d(d, num_channel, 4, 2, 1)14 15 def weight_init(self, mean, std):16 for m in self._modules:17 normal_init(self._modules[m], mean, std)18 19 def forward(self, input):20 x = F.relu(self.deconv1_bn(self.deconv1(input)))21 x = F.relu(self.deconv2_bn(self.deconv2(x)))22 x = F.relu(self.deconv3_bn(self.deconv3(x)))23 x = F.relu(self.deconv4_bn(self.deconv4(x)))24 x = torch.tanh(self.deconv5(x))25 26 return x 27

調べたこと

学習済みモデルと現在のモデルのアーキテクチャが異なるため、このようなエラーが出ていることはわかりました。
しかし、学習した時と現在のアーキテクチャは何も変えてません。
なぜパラメータが一致しないのか不明です。

補足情報(FW/ツールのバージョンなど)

python 3.9
torch 1.13.1
torchvision 0.14.1

コメントを投稿

0 コメント