# Python
import streamlit as st
import numpy as np
from PIL import Image
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms

# Index i is the agency label for model output class i (presumably the label
# order used at training time -- confirm against the training script).
AGENCY_NAMES = ("HYBE", "JYP", "SM", "STARSHIP", "YG")

# Side length in pixels of the square face crop fed to the CNN.
FACE_SIZE = 64

# Label returned by detect() when no usable face is found.
NO_FACE_LABEL = "Impossible to detect"

# Checkpoint to use for each option of the gender radio button.
MODEL_PATHS = {"男性": "./men_cnn.pt", "女性": "./women_cnn.pt"}

# Converts an HxWx3 uint8 array to a 3xHxW float tensor scaled to [-1, 1].
# Built once at import time instead of once per detected face.
_TRANSFORM = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])


class CNN(nn.Module):
    """Small 3-conv / 3-FC classifier over 64x64 RGB crops.

    Spatial size: 64 -conv5-> 60 -pool-> 30 -conv5-> 26 -pool-> 13
    -conv4-> 10, hence the 32*10*10 input features of ``fc1``. The attribute
    names must stay as-is so saved ``state_dict`` checkpoints still load.
    """

    def __init__(self):
        super().__init__()
        self.cn1 = nn.Conv2d(3, 6, 5)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.cn2 = nn.Conv2d(6, 16, 5)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.cn3 = nn.Conv2d(16, 32, 4)
        self.dropout = nn.Dropout2d()
        self.fc1 = nn.Linear(32 * 10 * 10, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 5)

    def forward(self, x):
        """Return (batch, 5) logits for a (batch, 3, 64, 64) input tensor."""
        x = self.pool1(F.relu(self.cn1(x)))
        x = self.pool2(F.relu(self.cn2(x)))
        x = self.dropout(F.relu(self.cn3(x)))
        # flatten(1) keeps the batch dimension explicit; a wrongly sized input
        # now fails at fc1 instead of being silently folded into the batch.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


def _to_rgb(image):
    """Return *image* as a 3-channel RGB array.

    2-D or single-channel grayscale and 4-channel RGBA (PIL channel order)
    inputs are converted; 3-channel input is returned unchanged.
    """
    if image.ndim == 2 or image.shape[2] == 1:
        return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    if image.shape[2] == 4:
        return cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
    return image


def detect(image, model, *, scale_factor=1.0001, min_face_size=0,
           cascade_path="haarcascade_frontalface_alt.xml"):
    """Find a face in *image* and classify it into an agency label.

    Args:
        image: NumPy image array in RGB channel order, e.g.
            ``np.array(Image.open(f).convert("RGB"))``. Grayscale and RGBA
            arrays are converted to RGB first.
        model: module mapping a (1, 3, 64, 64) tensor to (1, 5) logits;
            callers are expected to have called ``model.eval()``.
        scale_factor: ``scaleFactor`` for ``detectMultiScale``. The default
            1.0001 keeps the original setting but scans a huge number of
            scales and is slow; ~1.1 is the usual choice.
        min_face_size: detected faces whose width or height (pixels) is
            below this are ignored. The default 0 keeps every face.
        cascade_path: path to the Haar cascade XML file.

    Returns:
        The agency label of the last usable face in the detector's output,
        or ``NO_FACE_LABEL`` when there is none.

    Raises:
        FileNotFoundError: if the cascade file cannot be loaded.
    """
    classifier = cv2.CascadeClassifier(cascade_path)
    if classifier.empty():
        raise FileNotFoundError(f"could not load Haar cascade: {cascade_path}")

    image = _to_rgb(image)
    gray_image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    faces = classifier.detectMultiScale(gray_image, scaleFactor=scale_factor)

    # Only the last usable face's label is reported, so walk the detections
    # backwards and run the network once, on that face only.
    for x, y, width, height in faces[::-1]:
        if min(width, height) < min_face_size:
            continue
        face = cv2.resize(image[y:y + height, x:x + width], (FACE_SIZE, FACE_SIZE))
        batch = _TRANSFORM(face).unsqueeze(0)
        with torch.no_grad():
            output = model(batch)
        return type_to_name(output.argmax(dim=1))
    return NO_FACE_LABEL


def type_to_name(name_type):
    """Map a predicted class index to its agency name.

    Args:
        name_type: class index as a Python int or a one-element tensor
            (e.g. ``output.argmax(dim=1)``).

    Returns:
        The matching entry of ``AGENCY_NAMES``.

    Raises:
        ValueError: if the index is outside ``range(len(AGENCY_NAMES))``.
    """
    index = int(name_type)
    if not 0 <= index < len(AGENCY_NAMES):
        raise ValueError(f"unknown class index: {index}")
    return AGENCY_NAMES[index]


def _load_model(path):
    """Build a CNN, load ``checkpoint['model_state_dict']`` from *path* on CPU, set eval mode."""
    model = CNN()
    checkpoint = torch.load(path, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    return model


def main():
    """Render the upload page and store the prediction in ``st.session_state['output']``."""
    st.set_page_config(layout="centered")
    st.title("KPOP事務所顔診断")
    st.markdown("韓国の事務所別の顔を識別するアプリです")

    image = st.file_uploader("画像をアップロードしてください", type=['jpg', 'jpeg', 'png'])
    sex = st.radio("性別を選択", ("男性", "女性"), horizontal=True)
    st.button('判定結果を見る', on_click=change_page)

    if image is None:
        # Clear any result left over from a previously uploaded image so the
        # result page never shows a stale prediction.
        st.session_state['output'] = None
        return

    # The model is loaded only once there is something to classify.
    model = _load_model(MODEL_PATHS[sex])
    # convert("RGB") normalizes RGBA/palette/grayscale uploads to 3 channels.
    rgb_image = np.array(Image.open(image).convert("RGB"))
    st.session_state['output'] = detect(rgb_image, model)


def next_page():
    """Render the result page for ``st.session_state['output']``."""
    st.title('KPOP事務所顔診断')
    output = st.session_state.get('output')
    if output is None:
        # Reached via the result button before any image was uploaded.
        st.warning("画像をアップロードしてください")
    else:
        path = f"./output_img/{output}_out.png"
        try:
            img = np.array(Image.open(path))
        except FileNotFoundError:
            # No result image for this label (e.g. NO_FACE_LABEL): show text.
            st.write(output)
        else:
            st.image(img)
    # Pass the callback itself; calling it here would reset the page state
    # while the result page is still being drawn.
    st.button('戻る', on_click=back_page)


def change_page():
    """Button callback: switch to the result page on the next rerun."""
    st.session_state['page_control'] = 1


def back_page():
    """Button callback: return to the upload page and clear the result."""
    st.session_state['page_control'] = 0
    st.session_state['output'] = None


if __name__ == "__main__":
    if st.session_state.get("page_control") == 1:
        next_page()
    else:
        st.session_state["page_control"] = 0
        main()
# 0 コメント