敵対的学習を用いたドメイン適応のコーディングが正しいかどうか教えてください

私はPytorch geometricを用いてグラフの2値分類を行っています。
今、ノード・エッジの特徴量・ラベルについて同じ構造を持つデータセットA(1000程度のグラフで構成される)とデータセットB(規模的にはAと同じ)について、個々にPyGで学習すると高いパフォーマンス(0.9以上のaccuracy/F1)が得られるものの、学習器AでデータセットBを予測すると、すべて'0'と予測し、逆を行うとすべて'1'と予測するような状況であり、ドメインが異なっていると想像しています。特徴量は、AもBも同じ数値をもとに正規化してあり、その分布のヒストグラムをつくってみると、AとBの間で、分布が重なっていない特徴量が多いことがわかりました。
このような場合に、AにもBにも高いパフォーマンスを示すような学習器を作りたいと考えています。(最終目的は、同じ構造を持つデータセットCに対しても高いパフォーマンスを持つようにしたいのですが)
データセットAとデータセットBについては、両方ともにラベルを持っています。
そこで下記にコードを示すような敵対的学習を用いて、データセットAとデータセットBについてドメイン適応を行ってみましたが、学習が進まない結果となりました。lrを0.1/0.01/0.001と変えてみたが学習が進まない状況は同じでした。

#教えていただきたいこと
正直なところ、まだ初学者で、敵対的学習によるドメイン適応というものをしっかり理解していないため、このコードに自信が持てません。このコードの敵対的学習のコーディングについて、根本的な誤り・誤解がないかを教えていただきたいのです。自分としては、敵対的学習を理解して、将来的に特徴量の構造は同じですが、ラベルのないデータセットCに対応できるような敵対的学習のプログラムを作ることが最終目標です。

ちなみに普通の転移学習とファインチューニングは試みましたが、よいパフォーマンスは得られませんでした。(データセットAとデータセットBについて)
また、データセットAとデータセットBを混合させてデータセットを作り、学習させたテスト結果はは、そこそこのパフォーマンスを示しました。

(datasetを作る部分を省略しています) class FeatureExtractor(nn.Module): def __init__(self, in_channels, hidden_channels, out_channels): super(FeatureExtractor, self).__init__() self.conv1 = GENConv(in_channels, hidden_channels) self.fc = nn.Linear(hidden_channels, out_channels) self.dropout = nn.Dropout(p=0.5) # ドロップアウト:(p=ドロップアウト率) def forward(self, data): x = data.x.to(device) # xをデバイスに移動 edge_index = data.edge_index.to(device) # edge_indexをデバイスに移動 batch = data.batch.to(device) if 'batch' in data else None # batchが存在する場合、デバイスに移動 x = self.conv1(x, edge_index) # 全てのノードで各特徴量の平均をとる x = global_mean_pool(x, batch) # (バッチサイズ, 特徴量の数)に変換 x = self.dropout(x) x = self.fc(x) return x class Classifier(nn.Module): def __init__(self, out_channels): super(Classifier, self).__init__() self.fc = nn.Linear(out_channels, 2) # 2値分類 def forward(self, x): x = self.fc(x) return x class Discriminator(nn.Module): def __init__(self, out_channels): super(Discriminator, self).__init__() self.fc = nn.Linear(out_channels, 2) # 2ドメイン def forward(self, x): x = self.fc(x) return torch.sigmoid(x) # ロス計算関数(仮定) def compute_classification_loss(output, target): return nn.CrossEntropyLoss()(output, target) def compute_discriminator_loss(output_d_a, output_d_b): # ディスクリミネーターの損失計算 # ドメインAのデータを0, ドメインBのデータを1として識別できるように学習 label_a = torch.zeros(output_d_a.size(0), dtype=torch.long, device=device) label_b = torch.ones(output_d_b.size(0), dtype=torch.long, device=device) loss_a = nn.CrossEntropyLoss()(output_d_a, label_a) loss_b = nn.CrossEntropyLoss()(output_d_b, label_b) return loss_a + loss_b def compute_feature_extractor_loss(output_d_a, output_d_b): # 特徴抽出器の損失計算 # ディスクリミネーターを騙すことを目指す batch_size_a = output_d_a.size(0) batch_size_b = output_d_b.size(0) label_a = torch.ones(batch_size_a, dtype=torch.long, device=device) # ドメインBであると騙す label_b = torch.zeros(batch_size_b, dtype=torch.long, device=device) # ドメインAであると騙す loss_a = nn.CrossEntropyLoss()(output_d_a, label_a) # データセットAをBと騙すため、逆にする loss_b = nn.CrossEntropyLoss()(output_d_b, label_b) # データセットBをAと騙すため、逆にする return loss_a + loss_b in_channels = dataset_a.num_node_features hidden_channels = 64 out_channels = dataset_a.num_classes num_epochs = 100 # モデルの初期化 feature_extractor = FeatureExtractor(in_channels, hidden_channels, out_channels).to(device) classifier = Classifier(out_channels).to(device) discriminator = Discriminator(out_channels).to(device) # 最適化関数の設定 optimizer_fe = optim.Adam(feature_extractor.parameters(), lr=0.001) optimizer_c = optim.Adam(classifier.parameters(), lr=0.001) optimizer_d = optim.Adam(discriminator.parameters(), lr=0.001) # 学習過程を記録するためのリストを用意 train_losses = [] valid_losses = [] train_accuracies = [] valid_accuracies = [] for epoch in range(num_epochs): print('epoch:', epoch) feature_extractor.train() classifier.train() discriminator.train() epoch_train_loss = 0 correct_train = 0 total_train = 0 correct_train_a = 0 total_train_a = 0 correct_train_b = 0 total_train_b = 0 # データセットAとBのデータローダーからデータを同時に取得 for data_a, data_b in zip(train_loader_a, train_loader_b): data_a = data_a.to(device) data_b = data_b.to(device) # データセットAの特徴抽出 features_a = feature_extractor(data_a) # データセットBの特徴抽出 features_b = feature_extractor(data_b) # 分類器の訓練(データセットAとBの両方に対して行う) optimizer_c.zero_grad() output_c_a = classifier(features_a) # データセットAの特徴に基づく出力 output_c_b = classifier(features_b) # データセットBの特徴に基づく出力 labels_a = data_a.y.to(device) labels_b = data_b.y.to(device) loss_c_a = compute_classification_loss(output_c_a, labels_a) # データセットAに対する損失 loss_c_b = compute_classification_loss(output_c_b, labels_b) # データセットBに対する損失 loss_c = (loss_c_a + loss_c_b) / 2 # 平均損失を計算 loss_c.backward(retain_graph=True) # バックプロパゲーション optimizer_c.step() # 分類器による予測精度の計算(データセットA) _, predicted_a = torch.max(output_c_a, 1) total_train_a += data_a.y.size(0) correct_train_a += (predicted_a == data_a.y).sum().item() # 分類器による予測精度の計算(データセットB) _, predicted_b = torch.max(output_c_b, 1) total_train_b += data_b.y.size(0) correct_train_b += (predicted_b == data_b.y).sum().item() total_train = total_train_a + total_train_b correct_train = correct_train_a + correct_train_b # 敵対的学習(ディスクリミネーター) optimizer_d.zero_grad() output_d_a = discriminator(features_a.detach()) output_d_b = discriminator(features_b.detach()) loss_d = compute_discriminator_loss(output_d_a, output_d_b) loss_d.backward(retain_graph=True) optimizer_d.step() # 特徴抽出器の更新(ドメイン適応) optimizer_fe.zero_grad() # ディスクリミネーターを騙すための損失を計算 output_d_a = discriminator(features_a) output_d_b = discriminator(features_b) loss_fe = compute_feature_extractor_loss(output_d_a, output_d_b) loss_fe.backward() optimizer_fe.step() epoch_train_loss += loss_c.item() + loss_d.item() + loss_fe.item() # 精度と損失の記録 train_accuracy = 100 * correct_train / total_train train_losses.append(epoch_train_loss / (len(train_loader_a) + len(train_loader_b))) # 両データセットの平均損失 train_accuracies.append(train_accuracy) #print('train accuracy:', train_accuracy) print('train loss:', train_losses[-1], end='') # 訓練後の検証フェーズ feature_extractor.eval() classifier.eval() y_true = [] y_pred = [] epoch_valid_loss = 0 total_valid = 0 correct_valid = 0 with torch.no_grad(): for data in valid_loader_a: data = data.to(device) # データ全体をデバイスに移動 labels = data.y.to(device) # ラベルだけ別途デバイスに移動 features = feature_extractor(data) output = classifier(features) loss = compute_classification_loss(output, labels) epoch_valid_loss += loss.item() pred = output.argmax(dim=1) y_true.extend(data.y.tolist()) y_pred.extend(pred.tolist()) correct_valid += (pred == data.y).sum().item() total_valid += data.y.size(0) valid_losses.append(epoch_valid_loss / len(valid_loader_a)) valid_accuracy = 100 * correct_valid / total_valid valid_accuracies.append(valid_accuracy) valid_f1 = f1_score(y_true, y_pred, average='macro') #print('valid accuracy:', valid_accuracy, 'valid f1:', valid_f1) print('valid loss:', valid_losses[-1]) print(epoch, f'train_acc:{train_accuracy}',f'valid_acc:{valid_accuracy}', f'valid_f1:{valid_f1:.4f}') # モデルの最も良い重みを読み込む checkpoint = torch.load('best_model.pt') # 各モデルの状態辞書を復元 feature_extractor.load_state_dict(checkpoint['feature_extractor_state_dict'], strict=True) classifier.load_state_dict(checkpoint['classifier_state_dict'], strict=True) discriminator.load_state_dict(checkpoint['discriminator_state_dict'], strict=True) # 訓練後のモデル評価 (修正案) def evaluate(feature_extractor, classifier, loader): feature_extractor.eval() classifier.eval() y_true = [] y_pred = [] with torch.no_grad(): for data in loader: data = data.to(device) # データをGPUに移動 features = feature_extractor(data) outputs = classifier(features) _, predicted = torch.max(outputs, 1) y_true.extend(data.y.cpu().numpy()) y_pred.extend(predicted.cpu().numpy()) accuracy = accuracy_score(y_true, y_pred) f1 = f1_score(y_true, y_pred, average='macro') # 実際のラベルと予測されたラベルから混同行列を計算 cm = confusion_matrix(y_true, y_pred) return accuracy, f1, cm # モデルの評価結果の出力 train_accuracy_a, train_f1_a, train_cm_a = evaluate(feature_extractor, classifier, train_loader_a) test_accuracy_a, test_f1_a, test_cm_a = evaluate(feature_extractor, classifier, test_loader_a) train_accuracy_b, train_f1_b, train_cm_b = evaluate(feature_extractor, classifier, train_loader_b) test_accuracy_b, test_f1_b, test_cm_b = evaluate(feature_extractor, classifier, test_loader_b)

コメントを投稿

0 コメント