PyTorch 版 PSENet 的训练与部署方式(PyTorch 训练好的模型如何部署)深度揭秘

随心笔谈2年前发布 admin
198 0 0

文章摘要

这篇文章展示了一个基于 PyTorch 和 PSENet 构建的计算机视觉系统,用于检测图像中的文本区域。系统的主要组件包括数据预处理、模型加载和推理。文章详细描述了如何使用 OpenCV 进行图像预处理,如何构建模型并加载训练好的模型参数,以及如何使用自定义函数统计推理速度。最后,系统遍历测试目录中的图像,调用模型进行推理,并将检测到的区域用红色矩形框出后保存。文章展示了从图像预处理到结果保存的完整流程。

import torch
import numpy as np
import argparse
import os
import os.path as osp
import sys
import time
import json
from mmcv import Config
import cv2
from torchvision import transforms
from dataset import build_data_loader
from models import build_model
from models.utils import fuse_module
from utils import ResultFormat, AverageMeter
def prepare_image(image, target_size, device="cuda:0"):
    """Preprocess an OpenCV BGR image into a batched tensor for inference.

    The image is converted from BGR to RGB, resized to ``target_size`` and
    converted with ``transforms.ToTensor`` into a float tensor with a leading
    batch dimension of 1.

    :param image: image as returned by ``cv2.imread`` (BGR channel order).
    :param target_size: size passed straight to ``cv2.resize``, i.e.
        ``(width, height)``.
    :param device: device the returned tensor is moved to; defaults to
        ``"cuda:0"`` (the previously hard-coded device).
    :return: tensor of shape (1, 3, target_size[1], target_size[0]).
    :raises ValueError: if ``image`` is ``None`` (e.g. ``cv2.imread`` could
        not read the file).
    """
    if image is None:
        raise ValueError("image is None; the file could not be read")
    img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # cv2.resize takes dsize as (width, height): the resulting array has
    # height == target_size[1] and width == target_size[0].
    img = cv2.resize(img, target_size)
    # NOTE(review): the training pipeline may also normalize with ImageNet
    # mean/std after ToTensor; confirm against the dataset code and add
    # transforms.Normalize here if the checkpoint expects it.
    tensor = transforms.ToTensor()(img)  # (H, W, C) -> (C, H, W)
    tensor = tensor.unsqueeze(0)  # add batch dimension -> (1, C, H, W)
    return tensor.to(torch.device(device))
def report_speed(outputs, speed_meters):
    """Accumulate and print the per-stage timings reported by the model.

    Every entry of ``outputs`` whose key contains ``'time'`` is treated as a
    timing value. It is added to a running total, fed into the meter of the
    same name in ``speed_meters``, and that meter's average is printed. The
    summed time is then fed into ``speed_meters['total_time']`` and the
    resulting FPS (1 / average total time) is printed.

    :param outputs: dict returned by the model's forward pass.
    :param speed_meters: dict mapping each timing key (plus ``'total_time'``)
        to a meter exposing ``update(value)`` and an ``avg`` attribute.
    """
    total_time = 0
    for key, value in outputs.items():
        if 'time' not in key:
            continue
        total_time += value
        speed_meters[key].update(value)
        print('%s: %.4f' % (key, speed_meters[key].avg))

    total_meter = speed_meters['total_time']
    total_meter.update(total_time)
    if total_meter.avg > 0:
        print('FPS: %.1f' % (1.0 / total_meter.avg))
    else:
        # No timing keys were reported (average total time is 0); avoid a
        # ZeroDivisionError.
        print('FPS: n/a')
def load_model(cfg, checkpoint="psenet_r50_ic15_1024_finetune/checkpoint_580ep.pth.tar"):
    """Build the model from ``cfg``, load trained weights and prepare it for inference.

    :param cfg: mmcv ``Config``; ``cfg.model`` is passed to ``build_model``.
    :param checkpoint: path to a checkpoint file whose ``'state_dict'`` entry
        holds the model weights. ``None`` skips weight loading. Defaults to
        the path previously hard-coded in this function.
    :return: the model moved to CUDA, in eval mode, with conv and bn fused by
        ``fuse_module``.
    :raises FileNotFoundError: if ``checkpoint`` is given but is not a file.
    """
    model = build_model(cfg.model)
    model = model.cuda()
    model.eval()

    if checkpoint is not None:
        if not os.path.isfile(checkpoint):
            raise FileNotFoundError("No checkpoint found at '{}'".format(checkpoint))
        print("Loading model and optimizer from checkpoint '{}'".format(checkpoint))
        sys.stdout.flush()
        # Load onto the CPU first so checkpoints saved on any GPU can be read;
        # load_state_dict copies the values into the model's CUDA parameters.
        state = torch.load(checkpoint, map_location='cpu')
        prefix = 'module.'
        # Presumably saved from a DataParallel-wrapped model, which prefixes
        # every key with 'module.'. Strip the prefix only where it is present
        # so un-prefixed keys are not truncated.
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state['state_dict'].items()
        }
        model.load_state_dict(state_dict)

    # fuse conv and bn
    model = fuse_module(model)
    return model
if __name__ == '__main__':
    # Run detection on every image in src_dir and save annotated copies to
    # save_dir.
    src_dir = "testimg/"
    save_dir = "test_save/"
    # Passed to cv2.resize, so this is (width, height).
    target_size = (1376, 1024)
    os.makedirs(save_dir, exist_ok=True)

    cfg = Config.fromfile("PSENet/config/psenet/psenet_r50_ic15_1024_finetune.py")
    for d in [cfg, cfg.data.test]:
        d.update(dict(
            report_speed=False
        ))
    if cfg.report_speed:
        speed_meters = dict(
            backbone_time=AverageMeter(500),
            neck_time=AverageMeter(500),
            det_head_time=AverageMeter(500),
            det_pse_time=AverageMeter(500),
            rec_time=AverageMeter(500),
            total_time=AverageMeter(500)
        )

    model = load_model(cfg)
    model.eval()

    count = 0
    for img_name in sorted(os.listdir(src_dir)):
        img = cv2.imread(os.path.join(src_dir, img_name))
        if img is None:
            # cv2.imread returns None for unreadable / non-image files.
            print("skip unreadable file:", img_name)
            continue

        tensor = prepare_image(img, target_size=target_size)
        # Both sizes are (height, width). img_size is taken from the tensor
        # itself (N, C, H, W) so it always matches what the model sees.
        img_metas = dict(
            org_img_size=torch.tensor([[img.shape[0], img.shape[1]]]),
            img_size=torch.tensor([list(tensor.shape[2:])]),
        )
        data = dict(imgs=tensor, img_metas=img_metas, cfg=cfg)

        with torch.no_grad():
            outputs = model(**data)
        if cfg.report_speed:
            report_speed(outputs, speed_meters)

        for bbox in outputs['bboxes']:
            # Assumes each bbox is a flat sequence of interleaved x, y vertex
            # coordinates in original-image pixels; confirm against the model
            # head. Drawing the polygon keeps rotated detections accurate,
            # unlike an axis-aligned rectangle through two vertices.
            points = np.asarray(bbox, dtype=np.int32).reshape(-1, 2)
            cv2.polylines(img, [points], True, (0, 0, 255), 3)

        count += 1
        cv2.imwrite(os.path.join(save_dir, img_name), img)
        print("img test:", count)

© 版权声明

相关文章