テンソル変換された画像が出力されない

前提

pytorch実践入門という本の7章の画像分類モデルの構築を行っております。
鳥の画像をテンソル変換して出力したいのですが出力されません。どうしてでしょうか?よろしくお願いいたします。

やってほしいこと


なぜ画像が出力されずエラーとなるのかを教えていただきたいです。また、対処法として画像が出力されるためにはどんなコードを書けば良いのか教えていただきたいです。

発生している問題・エラーメッセージ


AttributeError Traceback (most recent call last)
<ipython-input-405-50be3884ac1f> in <module>
1 img, _ = cifar2[0]
2
----> 3 plt.imshow(img.permute(1, 2, 0))
4 plt.show()

/opt/anaconda3/lib/python3.8/site-packages/PIL/Image.py in getattr(self, name)
539 )
540 return self._category
--> 541 raise AttributeError(name)
542
543 @property

AttributeError: permute

該当のソースコード

%matplotlib inline from matplotlib import pyplot as plt import numpy as np import torch torch.set_printoptions(edgeitems=2) torch.manual_seed(123) import torch import torchvision cifar10_data = torchvision.datasets. CIFAR10( root='./cifar-10', train=True, download=True, transform=torchvision.transforms.ToTensor()) cifar10 = torchvision.datasets.CIFAR10(root='./cifar-10', train=True, download=True) # <1> cifar10_val = torchvision.datasets.CIFAR10(root='./cifar-10', train=False, download=True) # <2> type(cifar10).__mro__ len(cifar10) class_names = ['airplane','automobile','bird','cat','deer', 'dog','frog','horse','ship','truck'] fig = plt.figure(figsize=(8,3)) num_classes = 10 for i in range(num_classes): ax = fig.add_subplot(2, 5, 1 + i, xticks=[], yticks=[]) ax.set_title(class_names[i]) img = next(img for img, label in cifar10 if label == i) plt.imshow(img) plt.show() img, label = cifar10[99] img, label, class_names[label] plt.imshow(img) plt.show() from torchvision import transforms dir(transforms) to_tensor = transforms.ToTensor() img_t = to_tensor(img) img_t.shape tensor_cifar10 = torchvision.datasets.CIFAR10(root='./cifar-10', train=True, download=False, transform=transforms.ToTensor()) img_t, _ = tensor_cifar10[99] type(img_t) img_t.shape, img_t.dtype img_t.min(), img_t.max() plt.imshow(img_t.permute(1, 2, 0)) plt.show() imgs = torch.stack([img_t for img_t, _ in tensor_cifar10], dim=3) imgs.shape imgs.view(3, -1).mean(dim=1) imgs.view(3, -1).std(dim=1) transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)) transformed_cifar10 = torchvision.datasets.CIFAR10(root='./cifar-10', train=True, download=False, transform=transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465),(0.2470, 0.2435, 0.2616))])) img_t, _ = transformed_cifar10[99] plt.imshow(img_t.permute(1, 2, 0)) plt.show() label_map = {0: 0, 2: 1} class_names = ['airplane', 'bird'] cifar2 = [(img, label_map[label]) for img, label in cifar10 if label in [0, 2]] cifar2_val = [(img, label_map[label]) for img, label in cifar10_val if label in [0, 2]] import torch.nn as nn n_out = 2 model = nn.Sequential( nn.Linear( 3072, # <1> 512, # <2> ), nn.Tanh(), nn.Linear( 512, # <2> n_out, # <3> ) ) def softmax(x): return torch.exp(x) / torch.exp(x).sum() softmax(x).sum() x = torch.tensor([1.0, 2.0, 3.0]) softmax(x) softmax = nn.Softmax(dim=1) x = torch.tensor([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]) softmax(x) model = nn.Sequential( nn.Linear(3072, 512), nn.Tanh(), nn.Linear(512, 2), nn.Softmax(dim=1)) img, _ = cifar2[0] plt.imshow(img.permute(1, 2, 0)) plt.show()

コメントを投稿

0 コメント