ディレクトリエラーを解決したい

python

"""Frame-by-frame video anomaly detection demo.

Reads every image file in the directory given by ``--n`` (in sorted order),
feeds a sliding window of 16 frames through the feature extractor and the
classifier, and writes:

* ``<dir>_result/<frame name>``  - each frame annotated with FPS / score,
* ``<dir>_result.mp4``           - the annotated frames encoded by ffmpeg,
* ``<dir>_result.png``           - a plot of the score per frame.
"""
import torch
import glob
import numpy as np
import os
import subprocess

import torchvision.transforms as transforms
import torch.nn.functional as F
from torch import nn
from models.model import generate_model
from learner import Learner
from PIL import Image, ImageFilter, ImageOps, ImageChops

import random
import numbers
import pdb
import time
import cv2
from matplotlib import pyplot as plt
from tqdm import tqdm
import sys
import argparse

try:
    import accimage
except ImportError:
    accimage = None


parser = argparse.ArgumentParser(description='Video Anomaly Detection')
parser.add_argument('--n', default='', type=str, help='file name')
args = parser.parse_args()

# Number of consecutive frames fed to the feature extractor at once.
CLIP_LEN = 16
# A score above this value draws the red alarm border on the frame.
ALARM_THRESHOLD = 0.4


class ToTensor(object):
    """Convert a ``PIL.Image`` or ``numpy.ndarray`` to a tensor.

    An (H x W x C) input is returned as a (C x H x W) tensor.  Byte-valued
    data is converted to float and divided by ``norm_value`` (255 maps
    [0, 255] to [0.0, 1.0]; 1 keeps the [0, 255] range).
    """

    def __init__(self, norm_value=255):
        self.norm_value = norm_value

    def __call__(self, pic):
        """
        Args:
            pic (PIL.Image or numpy.ndarray): Image to be converted to tensor.
        Returns:
            Tensor: Converted image.
        """
        if isinstance(pic, np.ndarray):
            # handle numpy array
            img = torch.from_numpy(pic.transpose((2, 0, 1)))
            # backward compatibility
            return img.float().div(self.norm_value)

        if accimage is not None and isinstance(pic, accimage.Image):
            # accimage fills a float32 CHW buffer itself; no scaling is
            # applied on this path.
            nppic = np.zeros(
                [pic.channels, pic.height, pic.width], dtype=np.float32)
            pic.copyto(nppic)
            return torch.from_numpy(nppic)

        # handle PIL Image
        # NOTE: ``copy=False`` was dropped here. Under NumPy 2 it raises when
        # a copy is unavoidable, which is the case when converting a PIL
        # image to an array.
        if pic.mode == 'I':
            img = torch.from_numpy(np.array(pic, np.int32))
        elif pic.mode == 'I;16':
            img = torch.from_numpy(np.array(pic, np.int16))
        else:
            img = torch.ByteTensor(
                torch.ByteStorage.from_buffer(pic.tobytes()))
        # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
        if pic.mode == 'YCbCr':
            nchannel = 3
        elif pic.mode == 'I;16':
            nchannel = 1
        else:
            nchannel = len(pic.mode)
        img = img.view(pic.size[1], pic.size[0], nchannel)
        # put it from HWC to CHW format
        # yikes, this transpose takes 80% of the loading time/CPU
        img = img.transpose(0, 1).transpose(0, 2).contiguous()
        if isinstance(img, torch.ByteTensor):
            return img.float().div(self.norm_value)
        else:
            # Integer modes ('I', 'I;16') are returned unscaled.
            return img

    def randomize_parameters(self):
        pass


class Normalize(object):
    """Normalize a tensor in place, channel by channel: ``(t - mean) / std``."""

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        # Iterating the tensor yields views of its first dimension, so the
        # in-place ops below modify ``tensor`` itself.
        for t, m, s in zip(tensor, self.mean, self.std):
            t.sub_(m).div_(s)
        return tensor

    def randomize_parameters(self):
        pass


_to_tensor = ToTensor(1)  # norm_value=1 keeps pixel values in [0, 255]


def _load_frame(frame_path):
    """Load one frame file as a 3-channel float tensor (C x H x W)."""
    # The context manager closes the file handle; convert('RGB') makes
    # grayscale / RGBA frames fit the 3-channel input buffer.
    with Image.open(frame_path) as frame:
        return _to_tensor(frame.convert('RGB'))


#############################################################
#                        MAIN CODE                          #
#############################################################

# ---- Resolve and validate directories before loading any weights ----------
if not args.n:
    # The old default ('') made the glob pattern '/*', i.e. the filesystem
    # root.
    parser.error('--n must name the directory that contains the frames')

# normpath strips a trailing separator. Without that, 'dir/' produced the
# result directory 'dir/_result' *inside* the input directory, and the next
# run tried to open that directory as an image.
input_dir = os.path.normpath(args.n)
if not os.path.isdir(input_dir):
    parser.error('frame directory not found: %s' % input_dir)

save_path = input_dir + '_result'

# Keep regular files only; sub-directories cannot be opened as images.
img = sorted(
    p for p in glob.glob(os.path.join(glob.escape(input_dir), '*'))
    if os.path.isfile(p)
)
if not img:
    parser.error('no frame files found in: %s' % input_dir)

# Create the output directory once (parents included) instead of checking on
# every frame.
os.makedirs(save_path, exist_ok=True)

# ---- Models ----------------------------------------------------------------
model = generate_model()  # feature extractor
classifier = Learner().cuda()  # classifier

checkpoint = torch.load('./weight/RGB_Kinetics_16f.pth')
model.load_state_dict(checkpoint['state_dict'])
checkpoint = torch.load('./weight/ckpt.pth')
classifier.load_state_dict(checkpoint['net'])

model.eval()
classifier.eval()

# ---- Inference loop ---------------------------------------------------------
segment = len(img) // CLIP_LEN
x_value = list(range(segment))

# Sliding window of CLIP_LEN frames; frames are expected to be 240 x 320.
inputs = torch.zeros(1, 3, CLIP_LEN, 240, 320).cuda()
x_time = list(range(len(img)))
# The first CLIP_LEN frames only fill the window and get a score of 0.
# min() keeps y_pred the same length as x_time for short sequences.
y_pred = [0] * min(len(img), CLIP_LEN)

for num, i in enumerate(img):
    cv_img = cv2.imread(i)
    if cv_img is None:
        raise RuntimeError('OpenCV could not read frame: %s' % i)
    # Use the size of the *current* frame for the alarm border.
    h, w, _ = cv_img.shape

    if num < CLIP_LEN:
        inputs[:, :, num, :, :] = _load_frame(i)
        print(cv_img.shape)
        cv_img = cv2.putText(cv_img, 'FPS : 0.0, Pred : 0.0', (5, 15),
                             cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 200, 240), 2)
    else:
        # Shift the window left by one frame. clone() is required because
        # source and destination are overlapping views of the same tensor.
        inputs[:, :, :-1, :, :] = inputs[:, :, 1:, :, :].clone()
        inputs[:, :, -1, :, :] = _load_frame(i)

        start = time.time()
        with torch.no_grad():  # inference only: no autograd graph needed
            output, feature = model(inputs)
            feature = F.normalize(feature, p=2, dim=1)
            out = classifier(feature)
        score = out.item()  # .item() also waits for the GPU result
        y_pred.append(score)
        end = time.time()

        elapsed = max(end - start, 1e-9)  # guard against division by zero
        FPS = str(1 / elapsed)[:5]
        out_str = str(score)[:5]
        print(len(x_value) / len(y_pred))

        cv_img = cv2.putText(cv_img, 'FPS :' + FPS + ' Pred :' + out_str,
                             (5, 15), cv2.FONT_HERSHEY_SIMPLEX, 0.5,
                             (255, 200, 240), 2)
        if score > ALARM_THRESHOLD:
            cv_img = cv2.rectangle(cv_img, (0, 0), (w, h), (0, 0, 255), 3)

    # os.path.join works for relative and absolute save_path alike; the old
    # './' + save_path prefix broke absolute paths.
    out_file = os.path.join(save_path, os.path.basename(i))
    if not cv2.imwrite(out_file, cv_img):
        print('warning: could not write %s' % out_file, file=sys.stderr)

# ---- Video and plot ----------------------------------------------------------
# Argument list instead of a shell string, so spaces or shell metacharacters
# in the path are passed through literally.
# NOTE(review): the '%05d.jpg' pattern assumes the frames are named
# 00000.jpg, 00001.jpg, ... - confirm against the data set.
ffmpeg_cmd = ['ffmpeg', '-i', os.path.join(save_path, '%05d.jpg'),
              save_path + '.mp4']
try:
    ffmpeg_result = subprocess.run(ffmpeg_cmd)
except FileNotFoundError:
    print('warning: ffmpeg is not installed; skipping video export',
          file=sys.stderr)
else:
    if ffmpeg_result.returncode != 0:
        print('warning: ffmpeg exited with code %d' % ffmpeg_result.returncode,
              file=sys.stderr)

plt.plot(x_time, y_pred)
plt.savefig(save_path + '.png', dpi=300)
plt.cla()

コメントを投稿

0 コメント