"""
Classify webcam frames with a pretrained ResNet-18 and export the model to ONNX.

Created on Sun Sep 29 19:55:17 2019 (2019.9.29)
author: cuixingxing
email: cuixingxing150@gmail.com

@author: Administrator
"""
import functools

import cv2
import torch
import torchvision
from torchvision import transforms
# Spatial size (width, height) that every frame is resized to via cv2.resize
# before being fed to the network.
size = (224, 224)
# Text file whose N-th line is the label for class index N (read by readLabels).
labelpath = r'./synset_words.txt'
# Per-channel statistics torchvision's pretrained ImageNet classifiers were
# trained with; inference inputs must be normalized the same way.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Built once instead of on every frame: uint8 HWC image -> normalized float CHW tensor.
_preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])


def readLabels(labelpath=labelpath):
    """Read the label file and return its lines as a list of strings.

    Each returned line keeps its trailing newline (plain ``readlines``).
    The file is re-read on every call.
    """
    with open(labelpath, encoding='utf-8') as fid:
        labels = fid.readlines()
    return labels


@functools.lru_cache(maxsize=8)
def _cachedLabels(path):
    """Return readLabels(path) as an immutable tuple, reading each path only once."""
    # A tuple, so callers cannot mutate the cached value.
    return tuple(readLabels(path))


def classifyImg(net, img, topk=1, device='cuda'):
    """Classify one OpenCV image and print the top predictions.

    Args:
        net: classifier module; moved to ``device`` and put in eval mode.
            It is fed a (1, 3, H, W) float tensor, and its output row is indexed
            into the label file, so presumably an ImageNet model - confirm.
        img: image as returned by OpenCV (H x W x 3, uint8, BGR channel order).
        topk: number of best predictions to print (default 1).
        device: device the network and input tensor are moved to.

    Returns:
        (scores, idxs): 1-D tensors holding the ``topk`` highest raw network
        outputs and their class indices, best first.
    """
    net.to(device).eval()
    # OpenCV decodes to BGR, but the pretrained weights expect RGB.
    rgb = cv2.cvtColor(cv2.resize(img, size), cv2.COLOR_BGR2RGB)
    tensorImg = _preprocess(rgb).unsqueeze(0).to(device)
    with torch.no_grad():  # inference only: don't build an autograd graph per frame
        out = net(tensorImg)
    k = min(topk, out.shape[1])
    scores, idxs = torch.topk(out, k, dim=1)
    labels = _cachedLabels(labelpath)
    # Ranks are printed 1-based; rstrip drops the newline kept by readlines.
    for rank, (score, idx) in enumerate(zip(scores[0].tolist(), idxs[0].tolist()), start=1):
        print("top {:d},predict score:{:.5f},label:{:s}".format(rank, score, labels[idx].rstrip()))
    return scores[0], idxs[0]
if __name__ == '__main__':
    # True -> load the pretrained weights (legacy positional `pretrained` argument).
    model = torchvision.models.resnet18(True)
    cap = cv2.VideoCapture(0)
    try:
        while True:
            # Read first, and stop before classifying when the read fails.
            # Otherwise img would be None on the last iteration.
            isRead, img = cap.read()
            if not isRead:
                break
            classifyImg(model, img)
            cv2.imshow("", img)
            # waitKey returns an int (-1 when no key was pressed); mask to the key code.
            key = cv2.waitKey(10) & 0xFF
            if key == 27:  # Esc: quit
                break
            if key == ord(' '):  # Space: pause until any key is pressed
                cv2.waitKey()
    finally:
        cap.release()
        cv2.destroyAllWindows()
    # Put the model on the GPU in eval mode explicitly, so the export works even
    # if no frame was ever classified (classifyImg is what moved it before).
    model.to('cuda').eval()
    torch.onnx.export(model, torch.rand(1, 3, 224, 224).to('cuda'),
                      'resnet18_Torch.onnx', verbose=True)
|