Google Colaboratoryでのメモリクラッシュの対処について(圧縮センシング)

実現したいこと

Google Colaboratoryにおいて圧縮センシングを用いた画像復元をしたいと考えており、チャットGPTにてコードを作成し、走らせたところ、課金後のシステムRAM83.5 GBがクラッシュしてしまいましたので解決したいです。

発生している問題・分からないこと

使用可能な RAM をすべて使用した後で、セッションがクラッシュしました。

エラーメッセージ

error

1使用可能な RAM をすべて使用した後で、セッションがクラッシュしました。

該当のソースコード

Python

1import numpy as np 2import imageio 3import torch 4from google.colab import drive 5 6# Google Driveをマウント7drive.mount('/content/drive')8 9# 画像のパスを適切に指定10x_orig = imageio.v2.imread("/content/drive/My Drive/Colab Notebooks/006.bmp", pilmode='L') # グレースケールで読み込み11ny, nx = x_orig.shape 12 13k = round(nx * ny * 0.5)14ri = np.random.choice(nx * ny, k, replace=False)15y = x_orig.T.flat[ri]16 17# 総変動正則化を計算する関数18def total_variation(image):19 dx = torch.abs(image[:, 1:] - image[:, :-1])20 dy = torch.abs(image[1:, :] - image[:-1, :])21 return torch.sum(dx) + torch.sum(dy)22 23# 損失関数の定義24def loss_function_tv(recon, target, vx, alpha=0.1, beta=0.1):25 tv_loss = total_variation(vx)26 mse_loss = torch.mean((recon - target) ** 2)27 return mse_loss + alpha * tv_loss + beta * torch.norm(vx, 1)28 29# メモリ容量を確認し、64GB以上であればGPUを使用30device = torch.device("cuda" if torch.cuda.is_available() else "cpu")31 32batch_size = 8 # バッチサイズを小さく調整33num_batches = len(y) // batch_size 34 35vx = torch.nn.Parameter(torch.randn(nx * ny, requires_grad=True, dtype=torch.float32).to(device))36optimizer = torch.optim.Adam([vx], lr=0.01)37 38y = torch.tensor(y, dtype=torch.float32).to(device)39 40# yの次元を確認して適切な形状に変換する41y = y.reshape(-1, 1) # 1列に変換42 43# thetaの形状を確認し、適切に定義する44theta = torch.randn(num_batches, batch_size, nx * ny, dtype=torch.float32).to(device)45 46for epoch in range(1000):47 for i in range(num_batches):48 optimizer.zero_grad()49 50 y_batch = y[i * batch_size: (i + 1) * batch_size].reshape(-1, 1) # 1列に変換51 theta_batch = theta[i].reshape(batch_size, nx * ny) # thetaを適切な形状に変換52 53 recon = torch.matmul(theta_batch, vx)54 55 loss = loss_function_tv(recon, y_batch, vx, alpha=0.1, beta=0.1)56 loss.backward()57 optimizer.step()58 59# 画像の復元60with torch.no_grad():61 x_recovered = torch.matmul(theta.view(-1, nx * ny), vx).cpu().numpy().reshape(ny, nx).T 62 x_recovered_final = x_recovered.astype('uint8')63 64# 画像の保存先を指定65imageio.imwrite('./256_ver510.bmp', x_recovered_final)

試したこと・調べたこと

上記の詳細・結果

Google Colaboratoryでのメモリクラッシュについてgoogleで検索しましたが正直よくわかりません。

補足

以前Deep Learningを少ししていた際はbitmap画像1枚でデスクトップPCのメモリがフルで使用された記憶はないので、コードの中身が悪いのだとは思うのですが、、、

コメントを投稿

0 コメント