Deploying a PyTorch Model as a RESTful Service


Goal

Image classification: the model has already been trained and now needs to be exposed as a service.

Environment

requirement.txt

Flask==1.1.1
gevent==1.4.0
gunicorn==19.9.0
Pillow==6.2.1
torch==1.3.0
torchvision==0.4.0

gunicorn.conf.py

workers = 5
worker_class = "gevent"
bind = "0.0.0.0:8888"
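
The fixed worker count above is one possible choice. As a hedged alternative, the gunicorn documentation suggests (2 x CPU cores) + 1 as a starting point, and the sketch below derives that value at startup. Each worker process loads its own copy of the model, so memory may be the real limit; `MAX_WORKERS` is a hypothetical name introduced here for that cap, not a gunicorn setting.

```python
# gunicorn.conf.py (alternative sketch, not the original configuration)
import multiprocessing

# Hypothetical cap: every worker holds its own copy of the model in memory.
MAX_WORKERS = 8

# Rule of thumb from the gunicorn docs: (2 x cores) + 1, capped here.
workers = min(multiprocessing.cpu_count() * 2 + 1, MAX_WORKERS)
worker_class = "gevent"
bind = "0.0.0.0:8888"
```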

Loading the model

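import torch
import torch.nn as nn
from torchvision import models

model = models.resnet18(pretrained=False)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 6)  # replace the head with a 6-class output layer
# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only server
model.load_state_dict(torch.load("./models/epoch_18.pth", map_location="cpu"))
model.eval()  # inference mode: fixes batch-norm statistics, disables dropout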

Image preprocessing

import io

from PIL import Image
from torchvision import transforms

def prepare_image(image):
    """Turn raw image bytes into a normalized 1x3x224x224 tensor."""
    my_transforms = transforms.Compose([
        transforms.Resize(255),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.485, 0.456, 0.406],
            [0.229, 0.224, 0.225]),
    ])

    # Image.open returns a PIL.Image.Image object
    image = Image.open(io.BytesIO(image))
    if image.mode != 'RGB':
        image = image.convert("RGB")
    # unsqueeze(0) adds the batch dimension:
    # torch.Size([3, 224, 224]) -> torch.Size([1, 3, 224, 224])
    px = my_transforms(image).unsqueeze(0)
    return px
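
To make the last two transforms concrete: ToTensor scales 0-255 pixel values to [0, 1], and Normalize then computes (x - mean) / std per channel. The following is a minimal pure-Python sketch of that arithmetic for a single pixel. `normalize_pixel` is a hypothetical helper written for illustration, not part of torchvision.

```python
# Per-channel mean and std, copied from the Normalize call above.
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

def normalize_pixel(rgb):
    """Mimic ToTensor + Normalize for one (R, G, B) pixel of 0-255 ints."""
    # ToTensor: scale each channel to the range [0, 1].
    scaled = [c / 255.0 for c in rgb]
    # Normalize: subtract the channel mean, divide by the channel std.
    return [(x - m) / s for x, m, s in zip(scaled, MEAN, STD)]

white = normalize_pixel((255, 255, 255))
black = normalize_pixel((0, 0, 0))
print(white, black)
```

A white pixel maps to positive values and a black pixel to negative ones, which is the roughly zero-centered input range a network trained with these statistics expects.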

API endpoint

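from flask import Flask, jsonify, request

app = Flask(__name__)

# `classify` is not defined in the original post; it is assumed to be a
# sequence of the 6 class names.

@app.route("/predict", methods=["POST"])
def predict():
    data = {"success": False}

    if request.method == 'POST':
        image = request.files["image"].read()
        image = prepare_image(image)
        with torch.no_grad():  # gradients are not needed for inference
            outputs = model(image)
        _, label = outputs.max(1)
        # The class list must be in the same order as during training
        class_name = classify[label.item()]
        if class_name is not None:
            data["success"] = True
            data['predictions'] = class_name
    return jsonify(data)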
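
The step from raw model outputs to a class name can be illustrated without PyTorch. The sketch below is an assumption-laden illustration, not the original code: `top_prediction` is a hypothetical helper, the class names are placeholders, and the softmax probability is an extra the endpoint above does not return.

```python
import math

def top_prediction(logits, class_names):
    """Return (class_name, probability) for one row of raw model outputs."""
    # Softmax, subtracting the maximum first for numerical stability.
    peak = max(logits)
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Index of the largest score, the same choice outputs.max(1) makes.
    best = max(range(len(probs)), key=probs.__getitem__)
    return class_names[best], probs[best]

# Placeholder class names; the real list must follow the training order.
classes = ["class_0", "class_1", "class_2", "class_3", "class_4", "class_5"]
name, prob = top_prediction([0.1, 2.5, -1.0, 0.3, 0.0, 1.2], classes)
print(name, round(prob, 3))
```

If the list order differs from the label indices used in training, the lookup still succeeds but returns the wrong name, which is why the ordering comment in the endpoint matters.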

Reference: https://pytorch.org/tutorials/intermediate/flask_rest_api_tutorial.html

Adapted from: Deploying PyTorch in Python via a REST API with Flask