On learning the noise (latent vector z) for anomaly detection with AnoGAN

python

import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.utils.data import DataLoader
import torchsummary
import matplotlib.pyplot as plt

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class Discriminator_from_books(nn.Module):
    def __init__(self):
        super(Discriminator_from_books, self).__init__()
        in_ch = 1
        image_size = 128  # base number of feature maps, not the spatial size

        self.layer1 = nn.Sequential(
            nn.Conv2d(in_ch, image_size, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True))
        self.layer2 = nn.Sequential(
            nn.Conv2d(image_size, image_size * 2, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(image_size * 2),
            nn.LeakyReLU(0.2, inplace=True))
        self.layer3 = nn.Sequential(
            nn.Conv2d(image_size * 2, image_size * 4, kernel_size=3, stride=2, padding=0),
            nn.BatchNorm2d(image_size * 4),
            nn.LeakyReLU(0.2, inplace=True))
        self.layer4 = nn.Sequential(
            nn.Conv2d(image_size * 4, 1, kernel_size=3, stride=1, padding=0),
            nn.Sigmoid())

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        # the flattened layer-3 activation is the feature used later
        # for the discrimination (feature-matching) part of the anomaly score
        feature = out.view(out.size()[0], -1)
        out = self.layer4(out)
        return out.squeeze(), feature


class Generator_from_books(nn.Module):
    def __init__(self):
        super(Generator_from_books, self).__init__()
        z_dim = 100
        image_size = 128
        img_ch = 1

        # spatial sizes: 1x1 -> 3x3 -> 7x7 -> 14x14 -> 28x28 (FashionMNIST size)
        self.layer1 = nn.Sequential(
            nn.ConvTranspose2d(z_dim, image_size * 4, kernel_size=3, stride=1),
            nn.BatchNorm2d(image_size * 4),
            nn.ReLU(inplace=True))
        self.layer2 = nn.Sequential(
            nn.ConvTranspose2d(image_size * 4, image_size * 2, kernel_size=3, stride=2, padding=0),
            nn.BatchNorm2d(image_size * 2),
            nn.ReLU(inplace=True))
        self.layer3 = nn.Sequential(
            nn.ConvTranspose2d(image_size * 2, image_size, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(image_size),
            nn.ReLU(inplace=True))
        self.layer4 = nn.Sequential(
            nn.ConvTranspose2d(image_size, img_ch, kernel_size=4, stride=2, padding=1),
            nn.Tanh())

    def forward(self, z):
        out = self.layer1(z)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        return out


def weights_init(m):
    classname = m.__class__.__name__
    # the comparison must be against 'Conv' (capital C); the original 'conv'
    # never matches Conv2d / ConvTranspose2d, so the init was silently skipped
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
        m.bias.data.fill_(0)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)  # BatchNorm scale is initialised around 1, not 0
        m.bias.data.fill_(0)


generator = Generator_from_books().to(device)
generator.apply(weights_init)
torchsummary.summary(generator, (100, 1, 1))

discriminator = Discriminator_from_books().to(device)
discriminator.apply(weights_init)
torchsummary.summary(discriminator, (1, 28, 28))

criterion = nn.BCELoss()
optimizer_ds = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_gn = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# FashionMNIST, scaled to [-1, 1] to match the Tanh output of the generator
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))])

# The training DataLoader is not shown in the original post; a FashionMNIST
# training loader is assumed here so that the script runs end to end.
batch_size = 128
f_mnist_train = torchvision.datasets.FashionMNIST(
    root='C:\\Users\\lost\\in\\the\\rhythm\\at\\night',
    download=True, train=True, transform=transform)
dataloader = DataLoader(f_mnist_train, batch_size=batch_size, shuffle=True)

f_mnist_test = torchvision.datasets.FashionMNIST(
    root='C:\\Users\\lost\\in\\the\\rhythm\\at\\night',
    download=True, train=False, transform=transform)
ano_batch_size = 1
anoloader = DataLoader(f_mnist_test, batch_size=ano_batch_size, shuffle=True)

n_epoch = 10
gn_input_dim = 100
outf = 'C:\\micro\\boy\\\\and\\macro\\girl\\image_of_gan'
os.makedirs(outf, exist_ok=True)  # make sure the output folder exists
fixed_noise = torch.randn(batch_size, gn_input_dim, 1, 1, device=device)

for epoch in range(n_epoch):
    print('Epoch {}/{}'.format(epoch + 1, n_epoch))
    for itr, data in enumerate(dataloader):
        real_image = data[0].to(device)
        sample_size = real_image.size(0)
        noise = torch.randn(sample_size, gn_input_dim, 1, 1, device=device)
        real_target = torch.full((sample_size,), 1., device=device)
        fake_target = torch.full((sample_size,), 0., device=device)

        # --- update the discriminator ---
        discriminator.zero_grad()
        output, _ = discriminator(real_image)
        ds_real_err = criterion(output, real_target)
        true_dsout_mean = output.mean().item()

        fake_image = generator(noise)
        output, _ = discriminator(fake_image.detach())
        ds_fake_err = criterion(output, fake_target)
        fake_dsout_mean1 = output.mean().item()

        ds_err = ds_real_err + ds_fake_err
        ds_err.backward()
        optimizer_ds.step()

        # --- update the generator ---
        generator.zero_grad()
        output, _ = discriminator(fake_image)
        gn_err = criterion(output, real_target)
        gn_err.backward()
        fake_dsout_mean2 = output.mean().item()
        optimizer_gn.step()

        if itr % 100 == 0:
            print('({}/{}) ds_loss: {:.3f} - gn_loss: {:.3f} - true_out: {:.3f} - fake_out: {:.3f} >> {:.3f}'
                  .format(itr + 100, len(dataloader), ds_err.item(), gn_err.item(),
                          true_dsout_mean, fake_dsout_mean1, fake_dsout_mean2))

        if epoch == 0 and itr == 0:
            vutils.save_image(real_image, '{}/real_samples.png'.format(outf),
                              normalize=True, nrow=10)

    # save a sample from the fixed noise at the end of every epoch
    fake_image = generator(fixed_noise)
    vutils.save_image(fake_image.detach(),
                      '{}/generated_epoch_{:03d}.png'.format(outf, epoch + 1),
                      normalize=True, nrow=10)


def Anomaly_score(x, fake_img, D, Lambda=0.1):
    # residual loss: pixel-wise L1 distance between the input and G(z)
    residual_loss = torch.abs(x - fake_img)
    residual_loss = residual_loss.view(residual_loss.size()[0], -1)
    residual_loss = torch.sum(residual_loss, dim=1)

    # discrimination loss: L1 distance between the discriminator features
    _, x_feature = D(x)
    _, G_feature = D(fake_img)
    discrimination_loss = torch.abs(x_feature - G_feature)
    discrimination_loss = torch.sum(discrimination_loss, dim=1)

    loss_each = (1 - Lambda) * residual_loss + Lambda * discrimination_loss
    total_loss = torch.sum(loss_each)
    return total_loss, loss_each, residual_loss


# --- search for the latent vector z that reconstructs a test image ---
generator.eval()       # use the learned BatchNorm statistics during the search
discriminator.eval()

for itr, data in enumerate(anoloader):
    ano_image = data[0].to(device)
    ano_size = ano_image.size(0)
    print(ano_size)

    # z must be a leaf tensor with requires_grad=True; the original
    # `z.requiers_grad = True` was a typo that silently left it at False,
    # so the optimizer had nothing to update
    z = torch.randn(ano_size, gn_input_dim, 1, 1, device=device, requires_grad=True)
    z_optimizer = torch.optim.Adam([z], lr=0.003)

    # show the image produced by the initial, random z
    fake_ano_image = generator(z)
    np_fake_ano_image = fake_ano_image[0][0].to('cpu').detach().numpy().copy()
    plt.imshow(np_fake_ano_image, cmap='binary_r')
    plt.show()

    for epoch in range(2000 + 1):
        # G(z) has to be recomputed inside the loop; if it is generated only once
        # before the loop, the loss no longer depends on the updated z
        fake_ano_image = generator(z)
        loss, _, _ = Anomaly_score(ano_image, fake_ano_image, discriminator, Lambda=0.1)
        z_optimizer.zero_grad()
        loss.backward()
        z_optimizer.step()

        if epoch % 100 == 0:
            print('one of noise', z[0][0])
        if epoch % 1000 == 0:
            print('epoch {} || loss_total: {:.0f}'.format(epoch, loss.item()))

    break  # only the first test image is examined here
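
Once the latent search has converged, the value returned by Anomaly_score for the optimized z can be used directly as the anomaly score of the input image. The snippet below is a minimal sketch, not part of the original post: the helper score_image and the threshold 250.0 are made up for illustration (a real threshold would have to be calibrated on held-out normal images), and it reuses generator, discriminator, Anomaly_score, anoloader, gn_input_dim and device defined above.

python

def score_image(x, n_steps=500, lr=0.003, Lambda=0.1):
    """Optimise z for a single image x of shape (1, 1, 28, 28) and return its anomaly score."""
    z = torch.randn(x.size(0), gn_input_dim, 1, 1, device=device, requires_grad=True)
    opt = torch.optim.Adam([z], lr=lr)
    for _ in range(n_steps):
        fake = generator(z)                      # re-generate G(z) with the current z
        loss, _, _ = Anomaly_score(x, fake, discriminator, Lambda=Lambda)
        opt.zero_grad()
        loss.backward()
        opt.step()
    with torch.no_grad():                        # final score with the optimised z
        fake = generator(z)
        _, loss_each, _ = Anomaly_score(x, fake, discriminator, Lambda=Lambda)
    return loss_each.item(), fake                # per-image score and the reconstruction G(z)

# usage: flag an image as anomalous when its score exceeds a calibrated threshold
x, _ = next(iter(anoloader))
score, reconstruction = score_image(x.to(device))
print('anomaly score: {:.1f} -> {}'.format(score, 'anomalous' if score > 250.0 else 'normal'))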
