Attention Weightヒートマップが期待していたものと違う

train.py

1import torch 2import torch.optim as optim 3from torch.utils.data import DataLoader 4from learner import Learner 5from loss import MIL 6from dataset import Normal_Loader, Anomaly_Loader 7import os 8from sklearn import metrics 9import argparse 10import matplotlib.pyplot as plt 11import seaborn as sns 12import numpy as np 13from openpyxl import Workbook 14 15# コマンドライン引数をパース 16parser = argparse.ArgumentParser(description='PyTorch MIL Training') 17parser.add_argument('--lr', default=0.001, type=float, help='learning rate') 18parser.add_argument('--w', default=0.0010000000474974513, type=float, help='weight_decay') 19parser.add_argument('--modality', default='RGB', type=str, help='modality') 20parser.add_argument('--input_dim', default=1024, type=int, help='input_dim') 21parser.add_argument('--drop', default=0.3, type=float, help='dropout_rate') 22parser.add_argument('--seed', default=8111, type=int, help='random seed') 23parser.add_argument('--save_dir', default='./loss', type=str, help='directory to save the Excel file') 24parser.add_argument('--save_attention_epoch', default=35, type=int, help='epoch to save attention weights') 25args = parser.parse_args() 26 27# ベストのAUCを記録する変数 28best_auc = 0 29 30# 正常データと異常データのデータセットとデータローダーを準備 31normal_train_dataset = Normal_Loader(is_train=1, modality=args.modality) 32normal_test_dataset = Normal_Loader(is_train=0, modality=args.modality) 33anomaly_train_dataset = Anomaly_Loader(is_train=1, modality=args.modality) 34anomaly_test_dataset = Anomaly_Loader(is_train=0, modality=args.modality) 35 36normal_train_loader = DataLoader(normal_train_dataset, batch_size=30, shuffle=False) 37normal_test_loader = DataLoader(normal_test_dataset, batch_size=1, shuffle=True) 38anomaly_train_loader = DataLoader(anomaly_train_dataset, batch_size=30, shuffle=False) 39anomaly_test_loader = DataLoader(anomaly_test_dataset, batch_size=1, shuffle=True) 40 41# GPUが利用可能な場合はGPUを使用し、そうでなければCPUを使用 42device = 'cuda' if torch.cuda.is_available() else 'cpu' 43 44# モデルを定義し、指定したデバイスに送る 45model = Learner(input_dim=args.input_dim).to(device) 46 47# オプティマイザとスケジューラーを設定 48optimizer = torch.optim.Adagrad(model.parameters(), lr=args.lr, weight_decay=args.w) 49scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[25, 50]) 50 51# 損失関数を設定 52criterion = MIL 53 54# ランダムシードを設定 55seed = args.seed 56torch.manual_seed(seed) 57np.random.seed(seed) 58 59# エポックごとの損失値を保存するためのリスト 60loss_values = [] 61 62# Excelワークブックとシートの作成 63wb = Workbook() 64ws = wb.active 65ws.title = "Loss Values" 66ws.append(['Epoch', 'Loss', 'Sparsity', 'Smooth', 'MIL Loss', 'Scores']) # ヘッダー行の追加 67 68# トレーニング関数 69def train(epoch): 70 print(f'\nEpoch: {epoch}') 71 model.train() 72 train_loss = 0 73 epoch_loss = 0 74 epoch_sparsity = 0 75 epoch_smooth = 0 76 epoch_mil_loss = 0 77 78 epoch_scores = [] # 各エポックのスコアを保存するリスト 79 80 for i, (normal_inputs, anomaly_inputs) in enumerate(zip(normal_train_loader, anomaly_train_loader)): 81 unique_id = i 82 inputs = torch.cat([anomaly_inputs, normal_inputs], dim=1) # 異常データと正常データをセグメントについて連結 83 batch_size = inputs.shape[0] 84 inputs = inputs.to(device) 85 outputs, attention_weights = model(inputs) 86 # print(outputs.shape) 87 loss = criterion(outputs, batch_size) # MIL関数を使用して損失を計算 88 optimizer.zero_grad() 89 loss.backward() 90 optimizer.step() 91 train_loss += loss.item() 92 93 epoch_scores.extend(outputs.cpu().detach().numpy()) # スコアをリストに追加 94 95 if epoch == args.save_attention_epoch: # エポックが指定された値の場合 96 attention_dir = os.path.join(args.save_dir, 'attention_weights') 97 if not os.path.exists(attention_dir): 98 os.makedirs(attention_dir) 99 100 # 異常データのAttention Weightを保存 101 for j in range(len(anomaly_inputs)): 102 anomaly_attention_map = attention_weights[:, :32][j].cpu().detach().numpy() # 前半部分が異常データのAttention Weight 103 anomaly_attention_map = anomaly_attention_map.transpose() 104 plt.figure(figsize=(8, 8)) 105 sns.heatmap(anomaly_attention_map, cmap="hot", annot=False, vmax=0.6, vmin=0.0) 106 plt.savefig(os.path.join(attention_dir, f'anomaly_sample_{unique_id * len(anomaly_inputs) + j}.png')) 107 plt.close() 108 109 avg_loss = train_loss / len(normal_train_loader) 110 loss_values.append([epoch, avg_loss, 0, 0, avg_loss]) # 損失値をリストに追加(sparsity, smoothは0に設定) 111 ws.append([epoch, avg_loss, 0, 0, avg_loss, str(epoch_scores)]) # スコアをExcelに追加 112 113 print('Train Loss:', avg_loss) 114 scheduler.step() 115 116 # 損失値をExcelに保存 117 if not os.path.exists(args.save_dir): 118 os.makedirs(args.save_dir) 119 save_path = os.path.join(args.save_dir, 'loss_values.xlsx') 120 wb.save(save_path) 121 print(f'Loss values saved to {save_path}')

コメントを投稿

0 コメント