Google ColaboratoryでStyleGANを実装したい

前提・実現したいこと

Google ColaboratoryでStyleGANを実装しようとしています。(初学者です)
generatorをロードするのurlを変更することで他の画像も生成できると記載があった為、ImageNetデータセットで学習する予定です。(imagenetを用いたstyleganがstyleganだけのものがなく、stylegan+clipなどしかなかったため)

generatorのロードの部分でAttributeError: module 'config' has no attribute 'cache_dir'というエラーがでてしまいます(このエラーについてググってみましたが、どうすれば解決するのかよくわかりませんでした)。

以下の記事を参考にして進めていましたが、わかりませんでした。
https://teratail.com/questions/295390
https://qiita.com/pacifinapacific/items/1d6cca0ff4060e12d336
http://cedro3.com/ai/stylegan/
https://qiita.com/Phoeboooo/items/12d21916de56d125f0be

発生している問題・エラーメッセージ

Python

AttributeError Traceback (most recent call last)<ipython-input-9-c2a955d99cdc> in <module> 46 47 if __name__ == "__main__":---> 48 main() <ipython-input-9-c2a955d99cdc> in main() 14 # Load pre-trained network. 15 url = 'https://drive.google.com/file/d/1k_H6S-ePszz73lVCZrFRneaV6cbqerTm/view?usp=sharing' # karras2019stylegan-ffhq-1024x1024.pkl---> 16 with dnnlib.util.open_url(url, cache_dir=config.cache_dir) as f: 17 _G, _D, Gs = pickle.load(f) 18 # _G = Instantaneous snapshot of the generator. Mainly useful for resuming a previous training run. AttributeError: module 'config' has no attribute 'cache_dir'

該当のソースコード

Python

#git clone でStyleGANのコードを使えるようにする!git clone https://github.com/NVlabs/stylegan.git !pip install https://github.com/podgorskiy/dnnlib/releases/download/0.0.1/dnnlib-0.0.1-py3-none-any.whl #ディレクトリ移動!cd stylegan !pip install tensorflow==1.15.0!pip install tensorflow-gpu==1.15.0import os import pickle import numpy as np import PIL.Image import dnnlib import dnnlib.tflib as tflib !pip install config import config def main(): # Initialize TensorFlow. tflib.init_tf() # Load pre-trained network. url = 'https://drive.google.com/file/d/1k_H6S-ePszz73lVCZrFRneaV6cbqerTm/view?usp=sharing' # karras2019stylegan-ffhq-1024x1024.pkl with dnnlib.util.open_url(url, cache_dir=config.cache_dir) as f: _G, _D, Gs = pickle.load(f) # _G = Instantaneous snapshot of the generator. Mainly useful for resuming a previous training run. # _D = Instantaneous snapshot of the discriminator. Mainly useful for resuming a previous training run. # Gs = Long-term average of the generator. Yields higher-quality results than the instantaneous snapshot. # Print network details. Gs.print_layers() # Pick latent vector. rnd = np.random.RandomState(10) # seed = 10 latents0 = rnd.randn(1, Gs.input_shape[1]) latents1 = rnd.randn(1, Gs.input_shape[1]) latents2 = rnd.randn(1, Gs.input_shape[1]) latents3 = rnd.randn(1, Gs.input_shape[1]) latents4 = rnd.randn(1, Gs.input_shape[1]) latents5 = rnd.randn(1, Gs.input_shape[1]) latents6 = rnd.randn(1, Gs.input_shape[1]) num_split = 39 # 2つのベクトルを39分割 for i in range(40): latents = latents6+(latents0-latents6)*i/num_split # Generate image. fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True) images = Gs.run(latents, None, truncation_psi=0.7, randomize_noise=True, output_transform=fmt) # Save image. os.makedirs(config.result_dir, exist_ok=True) png_filename = os.path.join(config.result_dir, 'photo'+'{0:04d}'.format(i)+'.png') PIL.Image.fromarray(images[0], 'RGB').save(png_filename) if __name__ == "__main__": main()

補足情報(FW/ツールのバージョンなど)

参考にしたサイトのように公式のコードそのままでも実行することはできませんでした。

コメントを投稿

0 コメント