pytorchで関数への入力を複数処理したい

pytorchでclass内で定義した関数に対して複数の入力に対応した出力をしたいです。
下のプログラムは、入力x=[1,1,1]に対して出力を一つしています。

code

1import torch 2import torch.nn as nn 3import numpy as np 4import torch.optim as optim 5import copy 6import matplotlib.pyplot as plt 7 8class einsum_test(nn.Module): 9 def __init__(self): 10 super().__init__() 11 12 rnd = np.random.randn 13 self.A1 = nn.Parameter(torch.Tensor(rnd(2, 2))) 14 self.A2 = nn.Parameter(torch.Tensor(rnd(2,2,2))) 15 self.A3 = nn.Parameter(torch.Tensor(rnd(2,2))) 16 17 self.optimizer = optim.AdamW(self.parameters(), lr=0.01) 18 19 def f(self, x): 20 21 I0 = torch.tensor([1, 0], dtype=torch.float32) 22 I1 = torch.tensor([0, 1], dtype=torch.float32) 23 #縮約 24 I = I0 if x[0] == 0 else I1 25 p = torch.einsum("ia,i->a",self.A1,I) 26 I = I0 if x[1] == 0 else I1 27 p = torch.einsum("a,aib,i->b",p,self.A2,I) 28 I = I0 if x[1] == 0 else I1 29 p = torch.einsum("a,ai,i->...",p,self.A3,I) 30 31 return p 32 33model = einsum_test() 34x = torch.tensor([1,1,1]) 35y = model.f(x) 36print(y)

x=[1,1,1]と[0,1,0]と[0,0,1]の3つの入力データに対して、3つのスカラ値の出力を得たいのですが、
Can't call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead.
のエラーが出て自動微分を有効にすると多次元処理ができません。
下記プログラムを処理する方法はありますか?

code

1import torch 2import torch.nn as nn 3import numpy as np 4import torch.optim as optim 5import copy 6import matplotlib.pyplot as plt 7 8class einsum_test(nn.Module): 9 def __init__(self): 10 super().__init__() 11 12 rnd = np.random.randn 13 self.A1 = nn.Parameter(torch.Tensor(rnd(2, 2))) 14 self.A2 = nn.Parameter(torch.Tensor(rnd(2,2,2))) 15 self.A3 = nn.Parameter(torch.Tensor(rnd(2,2))) 16 17 self.optimizer = optim.AdamW(self.parameters(), lr=0.01) 18 19 def f(self, x): 20 21 I0 = torch.tensor([1, 0], dtype=torch.float32) 22 I1 = torch.tensor([0, 1], dtype=torch.float32) 23 #縮約 24 I = I0 if x[0] == 0 else I1 25 p = torch.einsum("ia,i->a",self.A1,I) 26 I = I0 if x[1] == 0 else I1 27 p = torch.einsum("a,aib,i->b",p,self.A2,I) 28 I = I0 if x[1] == 0 else I1 29 p = torch.einsum("a,ai,i->...",p,self.A3,I) 30 31 return p 32 33model = einsum_test() 34x = torch.tensor([[1,1,1],[0,1,0],[0,0,1]]) 35y = np.vectorize(model.f, signature='(n)->()')(x) 36print(y)

コメントを投稿

0 コメント