pythonで関数への入力を複数処理する方法を教えてください

pythonでnumpyのeinsumを使用するときに9個の配列を1つのデータとしてを一つずつ処理しているのですが、
これを複数入力データに対応できるようにバッチ処理を行いたいのですが、どのようにすればよいでしょうか。

バッチ処理していないeinsum計算が下記です。

python

1import numpy as np 2 3def einsum_test(x):4 A1 = np.arange(4).reshape(2,2)5 A2 = np.arange(8).reshape(2,2,2)6 A3 = np.arange(4).reshape(2,2)7 8 I0 = torch.tensor([1, 0])9 I1 = torch.tensor([0, 1])10 11 I = I0 if x[0] == 0 else I1 12 y = np.einsum("ia,i->a",A1,I)13 I = I0 if x[1] == 0 else I1 14 y = np.einsum("a,aib,i->b",y,A2,I)15 I = I0 if x[2] == 0 else I1 16 y = np.einsum("a,ai,i->...",y,A3,I)17 18 return y 19 20x = np.array([1,1,1])21y = einsum_test(x)22print(y)

output

1103

この処理を複数の入力で処理したいです。
xを多次元配列にして、これを入力としたいのですが、
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
とエラーが出てしまいます。

x

1x = np.array([[1,1,1],[0,1,0],[0,0,1]])

python

1import numpy as np 2 3def einsum_test_batch(x):4 A1 = np.arange(4).reshape(2,2)5 A2 = np.arange(8).reshape(2,2,2)6 A3 = np.arange(4).reshape(2,2)7 8 I0 = torch.tensor([1, 0])9 I1 = torch.tensor([0, 1])10 11 I = I0 if x[:,0] == 0 else I1 12 y = np.einsum("ia,i->a",A1,I)13 I = I0 if x[:,1] == 0 else I1 14 y = np.einsum("a,aib,i->b",y,A2,I)15 I = I0 if x[:,2] == 0 else I1 16 y = np.einsum("a,ai,i->...",y,A3,I)17 18 return y 19 20x = np.array([[1,1,1],[0,1,0],[0,0,1]])21y = einsum_test_batch(x)22print(y)

このプログラムを修正して下記の出力となるようにできますか?

expect

1103 214 319

コメントを投稿

0 コメント