pytorchで複数入力の配列を処理したい

python

1import torch 2 3x = torch.tensor([[0., 0., 1., 1., 0.],4 [1., 2., 4., 3., 0.],5 [2., 1., 1., 2., 2.],6 [4., 3., 4., 0., 1.],7 [3., 1., 0., 1., 3.]])8 9I0 = torch.tensor([[1,0,0,0,0]], dtype=torch.float)10I1 = torch.tensor([[0,1,0,0,0]], dtype=torch.float)11I2 = torch.tensor([[0,0,1,0,0]], dtype=torch.float)12I3 = torch.tensor([[0,0,0,1,0]], dtype=torch.float)13I4 = torch.tensor([[0,0,0,0,1]], dtype=torch.float)14# concatenate Ix tensors and select rows15I = torch.cat((I0, I1, I2, I3, I4))[x[:,0].to(int)]16 17print(I)18 19# tensor([[1., 0., 0., 0., 0.],20# [0., 1., 0., 0., 0.],21# [0., 0., 1., 0., 0.],22# [0., 0., 0., 0., 1.],23# [0., 0., 0., 1., 0.]])

コメントを投稿

0 コメント