UCF用のデータローダをCityscapes用に書き換えたい

実現したいこと

・Cityscapesデータセットのデータローダを作りたい

前提

ここに質問の内容を詳しく書いてください。
(例)
TypeScriptで●●なシステムを作っています。
■■な機能を実装中に以下のエラーメッセージが発生しました。

発生している問題・エラーメッセージ

UCF用のデータローダをCityscapes用に書き換えたい
・256*256へのリサイズ
・テンソルの次元の変更

該当のソースコード

Python

1import os.path 2from PIL import Image 3import torchvision.transforms as transforms 4from torch.utils.data import Dataset 5import torch 6import random 7import math 8import glob 9from pathlib import Path 10from skimage.io import imread 11 12# バッチ単位で乱数をふる13 14 15class AlignedDataset(Dataset):16 def __init__(self, config, purpose) -> None:17 # データセットクラスの初期化18 self.config = config 19 self.num_frames = config.num_frames # 16フレーム20 self.num_intervals = config.num_intervals # 間隔521 self.class_idx_dict = self.classToIdx(config)22 23 # imageファイルのリスト24 self.image_list = []25 # labelsファイルのリスト26 self.labels_list = []27 28 # trainがTrueの時,trainのパスを指定29 if purpose == 'train':30 image_folder = '/mnt/mizuno/dataset/cityscapes/leftImg8bit_trainvaltest/leftImg8bit/train/*/*_leftImg8bit.png'31 self.image_files = glob.glob(image_folder)32 labels_folder = '/mnt/mizuno/dataset/cityscapes/gtFine_trainvaltest/gtFine/train/*/*_gtFine_labelIds.png'33 self.labels_files = glob.glob(labels_folder)34 # trainがFalseの時,testのパスを指定35 else:36 image_folder = '/mnt/mizuno/dataset/cityscapes/leftImg8bit_trainvaltest/leftImg8bit/test/*/*_leftImg8bit.png'37 self.image_files = glob.glob(image_folder)38 labels_folder = '/mnt/mizuno/dataset/cityscapes/gtFine_trainvaltest/gtFine/test/*/*_gtFine_labelIds.png'39 self.labels_files = glob.glob(labels_folder)40 41 # ソートする42 self.image_files.sort()43 self.labels_files.sort()44 45 # ディレクトリまでのパスのリストを生成46 self.image_list.append(image_folder)47 self.labels_list.append(labels_folder)48 49 self.image_file_list = []50 self.labels_file_list = []51 52 # 短い方をsizeの値に合わせるように,アスペクト比を保ったままリサイズする53 def short_side(self, w, h, size):54 # https://github.com/facebookresearch/pytorchvideo/blob/a77729992bcf1e43bf5fa507c8dc4517b3d7bc4c/pytorchvideo/transforms/functional.py#L11855 if w < h:56 new_h = int(math.floor((float(h) / w) * size))57 new_w = size 58 else:59 new_h = size 60 new_w = int(math.floor((float(w) / h) * size))61 return new_w, new_h 62 63 def make_dataset(self, index, image_list, labels_list) -> dict:64 # ランダムなindexの画像を取得65 image_file_path = image_list[index]66 labels_file_path = labels_list[index]67 # target_file_path = target_list[index]68 69 # 開始場所をランダムで指定する70 rand_index = random.randint(71 0, len(image_file_path)72 )73 74 dict_key = ''75 # sampling_target = []76 sampling_image = []77 sampling_labels = []78 path_list = []79 80 seed = random.randint(0, 2**32)81 # for in range(開始値,最大値,間隔)82 for i in range(83 rand_index,84 len(image_file_path) - rand_index,85 186 ):87 88 # shape:H*W欲しい89 # ----image----90 image_path = image_file_path[i]91 label_numpy = imread(image_path)92 image = torch.from_numpy(label_numpy).unsqueeze(0)93 94 # shape:H*Wが欲しい95 # ----labels----96 # まずnumpyにする97 labels_numpy = imread(labels_file_path[i])98 # numpyをテンソルに変換する99 labels = torch.from_numpy(labels_numpy).unsqueeze(0)100 101 # sampling_target.append(target)102 sampling_image.append(image)103 sampling_labels.append(labels)104 105 path_list.append(str(Path(image_path).parent.parent.name))106 107 # リサイズされた画像の縦と横の長さをリスト化する108 # target.size()[1]:縦の長さ,target.size()[2]:横の長さ109 h, w = self.short_side(target.size()[1], target.size()[2], 256)110 transform_list = [111 transforms.Resize([h, w], Image.NEAREST),112 transforms.RandomCrop(self.config.crop_size)113 ]114 115 # transform.Compose:複数のTransformを連続して行うTransform116 transform = transforms.Compose(transform_list)117 118 # image_tensor.size():3*224*224にしたい119 120 # torch.stack:sampling_imageをdim=0の方向に連結121 image_tensor = torch.stack(sampling_image, dim=0)122 torch.manual_seed(seed)123 image_tensor = transform(image_tensor)124 125 # labels_tensor.size():3*224*224にしたい126 labels_tensor = torch.stack(sampling_labels, dim=0)127 torch.manual_seed(seed)128 labels_tensor = transform(labels_tensor)129 130 # dict_key = str(Path(image_path).parent.parent.parent.name)131 # label = self.class_idx_dict[dict_key]132 133 return {134 'image': image_tensor,135 'labels': labels_tensor 136 }137 138 # クラス名とデータセット内の対応するインデックスを対応づけるメソッド139 def classToIdx(self, args):140 class_list = sorted(141 entry.name for entry in os.scandir(args.image)142 if entry.is_dir())143 144 class_to_idx = {cls_name: i for i, cls_name in enumerate(class_list)}145 146 return class_to_idx 147 148 # __getitem__が呼び出された時,index(ランダム引数),image_file_list,label_file_listにして返す149 def __getitem__(self, index) -> dict:150 151 data = self.make_dataset(152 index,153 self.image_file_list,154 self.labels_file_list 155 )156 157 return data 158 159 def __len__(self):160 # 全画像ファイル数を返す161 return len(self.image_file_list)162

試したこと

・https://nsr-9.hatenablog.jp/entry/2021/09/05/100000
を参考にパスを変更
・不要部分を消した

補足情報(FW/ツールのバージョンなど)

ここにより詳細な情報を記載してください。

コメントを投稿

0 コメント