pythonのエラーについての質問です。
画像を読み込んで100×100に分割し、その画像と教師ラベルを使用してCNNで学習をしたいです。
その際に、kerasのImageDataGeneratorクラスを用いて画像を学習直前で変換したいのですが、下記エラーが出てしまいうまくいきません。
また、環境は次のとおりです
windows10(64bit)
anaconda 4.13.0
python 3.7.13
TensorFlow 2.3
Spyder 5.1.5を利用
エラー
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
ソースコード
# -*- coding: utf-8 -*- # coding: utf-8 import numpy as np from PIL import Image import matplotlib.pyplot as plt import tensorflow as tf from tensorflow.keras.models import Model from tensorflow.keras.preprocessing import image from tensorflow.keras.preprocessing.image import ImageDataGenerator from tensorflow.keras.optimizers import Adam def kiritori(img): img_array=np.array(img,dtype='int16') #(50行,83列,3次元)の3次元配列 a=np.split(img_array,img.height/100) img_data=np.split(a[0],img.width/100,axis=1)#stackする土台づくり for i in range(len(a)-1): img_data=np.vstack([img_data,np.split(a[i+1],img.width/100,axis=1)]) return img_data x_train=kiritori(Image.open('3Gatu_Data/DJI_0730.JPG'))#(#(1200,100,100,3) train_label=np.loadtxt(fname='label/3Gatu_730_100.csv',delimiter=',',dtype='int16')#(30, 40) train_label=np.reshape(train_label,(train_label.size)) #(1200) rati_train_data =x_train[np.where(train_label==0)] hamagou_train_data =x_train[np.where(train_label==1)] kouboumugi_train_data =x_train[np.where(train_label==2)] x_train=kiritori(Image.open('3Gatu_Data/DJI_0777.JPG'))#(1200,100,100,3) train_label=np.loadtxt(fname='label/3Gatu_777_100.csv',delimiter=',',dtype='int16')#(30,40) train_label=np.reshape(train_label,(train_label.size)) #(1200) rati_train_data =np.vstack([rati_train_data,x_train[np.where(train_label==0)]]) hamagou_train_data =np.vstack([hamagou_train_data,x_train[np.where(train_label==1)]]) kouboumugi_train_data =np.vstack([kouboumugi_train_data,x_train[np.where(train_label==2)]]) x_train=kiritori(Image.open('3Gatu_Data/DJI_0908.JPG'))#(1200,100,100,3) train_label=np.loadtxt(fname='label/3Gatu_908_100.csv',delimiter=',',dtype='int16')#(30,40) train_label=np.reshape(train_label,(train_label.size)) #(1200) rati_train_data =np.vstack([rati_train_data,x_train[np.where(train_label==0)]]) hamagou_train_data =np.vstack([hamagou_train_data,x_train[np.where(train_label==1)]]) kouboumugi_train_data =np.vstack([kouboumugi_train_data,x_train[np.where(train_label==2)]]) x_train=rati_train_data x_train=np.vstack([x_train,hamagou_train_data]) x_train=np.vstack([x_train,kouboumugi_train_data]) val_test=kiritori(Image.open('3Gatu_Data/DJI_0957.JPG')) test_label=np.loadtxt(fname='label/3Gatu_957_100.csv',delimiter=',',dtype='int16')#(30,40) test_label=np.reshape(test_label,(test_label.size)) #(1200) rati_test_data =val_test[np.where(test_label==0)] hamagou_test_data =val_test[np.where(test_label==1)] kouboumugi_test_data =val_test[np.where(test_label==2)] val_test=rati_test_data val_test=np.vstack([val_test,hamagou_test_data]) val_test=np.vstack([val_test,kouboumugi_test_data]) test_data=Image.open('3Gatu_Data/DJI_0957.JPG') copy_back=test_data.copy() x_test=kiritori(test_data) y_train=np.full((rati_train_data.shape[0],1),0) y_train=np.vstack([y_train,np.full((hamagou_train_data.shape[0],1),1)]) y_train=np.vstack([y_train,np.full((kouboumugi_train_data.shape[0],1),2)]) y_test=np.full((rati_test_data.shape[0],1),0) y_test=np.vstack([y_test,np.full((hamagou_test_data.shape[0],1),1)]) y_test=np.vstack([y_test,np.full((kouboumugi_test_data.shape[0],1),2)]) x_train = x_train/255. x_test = x_test/255. from tensorflow.keras.utils import to_categorical y_train = to_categorical(y_train,3) y_test = to_categorical(y_test,3) x_train_datagen = ImageDataGenerator( rotation_range=90, vertical_flip=True, brightness_range = [0.7, 1.117], horizontal_flip=True, ) x_train_generator = x_train_datagen.flow(x_train, subset="training", batch_size=10 ) val_datagen = ImageDataGenerator(rescale=1.0 / 255) val_test_generator = val_datagen.flow(val_test, subset="validation", batch_size=10 ) from tensorflow.python.keras.models import Sequential model = Sequential() from tensorflow.python.keras.layers import Conv2D model.add( #畳み込み Conv2D( filters=6, #出力 input_shape=(100,100,3), kernel_size=(3,3), #フィルタサイズ strides=(1,1), padding='valid', activation='relu' ) ) from tensorflow.python.keras.layers import MaxPooling2D from tensorflow.python.keras.layers import Dropout model.add(MaxPooling2D(pool_size=(2,2))) #マックスプーリング model.add(Dropout(0.25)) #ドロップアウト from tensorflow.python.keras.layers import Flatten model.add(Flatten()) #2次元配列に from tensorflow.python.keras.layers import Dense model.add(Dense(units=3456,activation='relu')) #全結合 from tensorflow.python.keras.layers import Dropout model.add(Dropout(0.5)) model.add(Dense(units=3,activation='softmax',name="f3")) #全結合 model.summary() model.compile( optimizer='adam', #自動で学習率が設定される loss='categorical_crossentropy', #多分類のときにしていできる交差エントロピー metrics=['accuracy'] ) history_model = model.fit_generator( x_train_generator, y_train, epochs=2, steps_per_epoch = len(x_train)/20, validation_data=(val_test_generator, y_test), validation_steps = len(val_test)/20, shuffle=True, ) #精度のアウトプット probs=model.predict(x_test) #学習したパラメータで精度を出力 probs_max=np.reshape(np.argmax(probs,axis=1),(int(test_data.height/100),int(test_data.width/100)))#(30,40)2次元配列 np.savetxt('Output_file/cnn_output_3Gatu_100_500.txt',probs_max,fmt="%d",delimiter=",") #モデルの保存 model_json_str = model.to_json() open('Output_file/cnn_model_3Gatu_100_500.json', 'w').write(model_json_str) model.save_weights('Output_file/mnist_mlp_3Gatu_100_500_weights.h5'); score = model.evaluate(val_test, y_test, verbose=0) print('Test loss :', score[0]) print('Test accuracy :', score[1]) loss = history_model.history['loss'] #訓練データの誤差 val_loss = history_model.history['val_loss']#テストデータ誤差 accuracy = history_model.history['accuracy'] #訓練データの誤差 val_accuracy = history_model.history['val_accuracy']#テストデータ誤差 np.savetxt('Output_file/cnn_loss_3Gatu_100.csv',loss,fmt="%.6f",delimiter=",") np.savetxt('Output_file/cnn_val_loss_3Gatu_100.csv',val_loss,fmt="%.6f",delimiter=",") np.savetxt('Output_file/cnn_acc_3Gatu_100.csv',accuracy,fmt="%.6f",delimiter=",") np.savetxt('Output_file/cnn_val_acc_3Gatu_100.csv',val_accuracy,fmt="%.6f",delimiter=",")
すみませんが,どなたかご教授頂けると幸いです。
よろしくお願いします。
0 コメント