Goal
Image classification: the model has already been trained and now needs to be served over HTTP.
Environment
requirements.txt
```
Flask==1.1.1
gevent==1.4.0
gunicorn==19.9.0
Pillow==6.2.1
torch==1.3.0
torchvision==0.4.0
```
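gunicorn.conf.py
```python
workers = 5               # worker processes; by default each one imports the app and loads its own copy of the model
worker_class = "gevent"   # gevent-based async workers (requires the gevent package)
bind = "0.0.0.0:8888"     # listen on all interfaces, port 8888
# Start with: gunicorn -c gunicorn.conf.py <module>:app   (<module> = the file that defines app)
```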
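Loading the model
```python
import torch
import torch.nn as nn
from torchvision import models

model = models.resnet18(pretrained=False)  # architecture only; weights come from our checkpoint
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 6)  # 6-class head, matching the trained model
model.load_state_dict(torch.load("./models/epoch_18.pth", map_location="cpu"))  # loads on CPU-only hosts too
model.eval()  # inference mode: BatchNorm uses running statistics, Dropout is disabled
```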
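If `load_state_dict` fails with missing/unexpected keys whose names all start with `module.`, the checkpoint was probably saved from a model wrapped in `nn.DataParallel`. The original does not say how `epoch_18.pth` was saved, so this is only an assumption. Under that assumption, a small helper (the name `strip_module_prefix` is hypothetical) can rename the keys before loading:

```python
from collections import OrderedDict


def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading `prefix` removed from every key that has it.

    Hypothetical helper: only needed if the checkpoint was saved from an nn.DataParallel model.
    """
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )
```

Usage would be `model.load_state_dict(strip_module_prefix(torch.load("./models/epoch_18.pth", map_location="cpu")))`. Keys without the prefix pass through unchanged, so the helper is harmless on a checkpoint that was saved normally.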
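Image preprocessing
```python
import io

from PIL import Image
from torchvision import transforms


def prepare_image(image_bytes):
    """Turn the raw bytes of an uploaded image into a normalized 1x3x224x224 tensor."""
    my_transforms = transforms.Compose([
        transforms.Resize(255),       # scale the shorter side to 255 px, keeping the aspect ratio
        transforms.CenterCrop(224),   # cut out the central 224x224 region
        transforms.ToTensor(),        # HWC uint8 in [0, 255] -> CHW float in [0, 1]
        transforms.Normalize([0.485, 0.456, 0.406],   # ImageNet per-channel means
                             [0.229, 0.224, 0.225]),  # ImageNet per-channel stds
    ])
    image = Image.open(io.BytesIO(image_bytes))
    if image.mode != "RGB":
        image = image.convert("RGB")  # e.g. grayscale, palette or RGBA uploads
    return my_transforms(image).unsqueeze(0)  # add the batch dimension
```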
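To see what `Resize(255)` followed by `CenterCrop(224)` does to an upload of a given size, here is a plain-Python sketch (no torchvision needed) that recomputes the geometry, plus the per-channel arithmetic of `ToTensor` + `Normalize`. It follows the integer truncation and rounding that torchvision's functional transforms use, as far as I know. `resize_then_center_crop` and `normalize_pixel` are hypothetical illustration helpers, not torchvision APIs.

```python
MEAN = (0.485, 0.456, 0.406)  # same values as transforms.Normalize above
STD = (0.229, 0.224, 0.225)


def resize_then_center_crop(width, height, resize=255, crop=224):
    """Return the resized (width, height) and the crop box (left, top, right, bottom).

    Resize with an int scales the shorter side to `resize` and keeps the aspect ratio;
    CenterCrop then takes a crop x crop square from the middle.
    """
    if width <= height:
        new_w, new_h = resize, int(resize * height / width)
    else:
        new_w, new_h = int(resize * width / height), resize
    top = int(round((new_h - crop) / 2.0))
    left = int(round((new_w - crop) / 2.0))
    return (new_w, new_h), (left, top, left + crop, top + crop)


def normalize_pixel(rgb):
    """ToTensor maps 0..255 to 0..1; Normalize then computes (x - mean) / std per channel."""
    return tuple((v / 255.0 - m) / s for v, m, s in zip(rgb, MEAN, STD))


print(resize_then_center_crop(640, 480))  # → ((340, 255), (58, 16, 282, 240))
```

A 640x480 photo is therefore scaled to 340x255 and only its central 224x224 square reaches the model, so objects near the left and right edges are cut off. A pure white pixel ends up at roughly (2.25, 2.43, 2.64) after normalization.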
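API endpoint
```python
import torch
from flask import Flask, jsonify, request

app = Flask(__name__)
# `model` and `prepare_image` are defined above; `classify` (class index -> class name)
# is not shown in the original and must be defined for the 6 classes.


@app.route("/predict", methods=["POST"])
def predict():
    data = {"success": False}
    if request.method == "POST":
        image = request.files["image"].read()  # raw bytes of the uploaded form field "image"
        image = prepare_image(image)
        with torch.no_grad():  # no gradient tracking needed for inference
            outputs = model(image)  # call the module instead of model.forward so hooks run
        _, label = outputs.max(1)  # index of the highest-scoring class
        class_name = classify[label.item()]
        if class_name is not None:
            data["success"] = True
            data["predictions"] = class_name
    return jsonify(data)
```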
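The original does not include a client. As a minimal sketch using only the standard library, the endpoint can be called by sending the image as a multipart/form-data field named `image`, which is what `request.files["image"]` reads. The helper names `build_multipart` and `request_prediction` are hypothetical, and the default URL assumes the service runs locally on the port from the gunicorn config.

```python
import json
import os
import urllib.request
import uuid


def build_multipart(field_name, filename, payload, content_type="application/octet-stream"):
    """Encode one file field as a multipart/form-data body; returns (body, Content-Type header)."""
    boundary = uuid.uuid4().hex  # random boundary; a collision with the payload is extremely unlikely
    head = (
        f"--{boundary}\r\n"
        f'Content-Disposition: form-data; name="{field_name}"; filename="{filename}"\r\n'
        f"Content-Type: {content_type}\r\n\r\n"
    ).encode("utf-8")
    tail = f"\r\n--{boundary}--\r\n".encode("utf-8")
    return head + payload + tail, f"multipart/form-data; boundary={boundary}"


def request_prediction(image_path, url="http://127.0.0.1:8888/predict"):
    """POST an image file as the "image" field and return the decoded JSON response."""
    with open(image_path, "rb") as f:
        body, content_type = build_multipart("image", os.path.basename(image_path), f.read())
    req = urllib.request.Request(url, data=body, headers={"Content-Type": content_type}, method="POST")
    with urllib.request.urlopen(req) as resp:
        return json.loads(resp.read().decode("utf-8"))
```

With the service running, `request_prediction("test.jpg")` (a hypothetical file) should return a dict shaped like `{"success": True, "predictions": "<class name>"}`. From a shell, `curl -F "image=@test.jpg" http://127.0.0.1:8888/predict` sends the same kind of request.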
Reference: https://pytorch.org/tutorials/intermediate/flask_rest_api_tutorial.html
Adapted from: Deploying PyTorch in Python via a REST API with Flask