pytorchで2xNのテンソルに対し, 1行目をグループのインデックスとしたとき0行目の最小値を求めたい

実現したいこと

pytorchで2xNのテンソルに対し, 1行目をグループのインデックスとしたとき各グループにおける0行目の最小値を求めたいです.

例えば,
tensor([[3, 4, 1, 3, 5, 2, 2, 1, 1],[1, 1, 2, 2, 2, 3, 3, 4, 5]])というテンソルがあったとして, 1行目がグループを表し, 0行目がデータを表します.
このとき, 例えばグループ1のデータは1行目が1の0, 1列目のデータで[3,4]なので最小値は3となり出力は3となります. これをグループ1~5まで行い出力は[3, 1, 2, 1, 1]となります.

ナイーブに行うと0行目のtensorを1行目のグループに従って最小値をとるやり方でできますがこれに近い動作を行うpytorchの組み込み関数をご存知のかたいらっしゃらないでしょうか?

補足情報(FW/ツールのバージョンなど)

pytorch 2.0.0
python 3.7.13

コメントを投稿

0 コメント