PyTorch 通過帶有 Flask 的 REST API 在 Python 中部署 PyTorch

2020-09-09 14:28 更新
原文: https://pytorch.org/tutorials/intermediate/flask_rest_api_tutorial.html

作者: Avinash Sajjanshetty

在本教程中,我們將使用 Flask 部署 PyTorch 模型,并公開用于模型推理的 REST API。 特別是,我們將部署預訓練的 DenseNet 121 模型來檢測圖像。

小費

此處使用的所有代碼均以 MIT 許可發(fā)布,可在 Github 上找到。

這是在生產(chǎn)中部署 PyTorch 模型的系列教程中的第一篇。 到目前為止,以這種方式使用 Flask 是開始為 PyTorch 模型提供服務的最簡單方法,但不適用于具有高性能要求的用例。 為了那個原因:

  • 如果您已經(jīng)熟悉 TorchScript,可以直接進入我們的用 C ++加載 TorchScript 模型教程。
  • 如果您首先需要在 TorchScript 上進行復習,請查看我們的 TorchScript 入門教程。

API 定義

我們將首先定義 API 端點,請求和響應類型。 我們的 API 端點將位于/predict,該端點通過包含圖片的file參數(shù)接受 HTTP POST 請求。 響應將是包含預測的 JSON 響應:

  1. {"class_id": "n02124075", "class_name": "Egyptian_cat"}

依存關系

通過運行以下命令來安裝所需的依賴項:

  1. $ pip install Flask==1.0.3 torchvision-0.3.0

簡單的 Web 服務器

以下是一個簡單的網(wǎng)絡服務器,摘自 Flask 的文檔

  1. from flask import Flask
  2. app = Flask(__name__)
  3. @app.route('/')
  4. def hello():
  5. return 'Hello World!'

將以上代碼段保存在名為app.py的文件中,您現(xiàn)在可以通過輸入以下內(nèi)容來運行 Flask 開發(fā)服務器:

  1. $ FLASK_ENV=development FLASK_APP=app.py flask run

當您在網(wǎng)絡瀏覽器中訪問http://localhost:5000/時,將看到Hello World!文字

我們將對上面的代碼段進行一些更改,以使其適合我們的 API 定義。 首先,我們將方法重命名為predict。 我們將端點路徑更新為/predict。 由于圖像文件將通過 HTTP POST 請求發(fā)送,因此我們將對其進行更新,使其也僅接受 POST 請求:

  1. @app.route('/predict', methods=['POST'])
  2. def predict():
  3. return 'Hello World!'

我們還將更改響應類型,以使其返回包含 ImageNet 類 ID 和名稱的 JSON 響應。 更新后的app.py文件現(xiàn)在為:

  1. from flask import Flask, jsonify
  2. app = Flask(__name__)
  3. @app.route('/predict', methods=['POST'])
  4. def predict():
  5. return jsonify({'class_id': 'IMAGE_NET_XXX', 'class_name': 'Cat'})

推理

在下一部分中,我們將重點介紹編寫推理代碼。 這將涉及兩部分,第一部分是準備圖像,以便可以將其饋送到 DenseNet;第二部分,我們將編寫代碼以從模型中獲取實際的預測。

準備圖像

DenseNet 模型要求圖像為尺寸為 224 x 224 的 3 通道 RGB 圖像。我們還將使用所需的均值和標準偏差值對圖像張量進行歸一化。 您可以在上閱讀有關它的更多信息。

我們將使用torchvision庫中的transforms并建立一個轉(zhuǎn)換管道,該轉(zhuǎn)換管道可根據(jù)需要轉(zhuǎn)換圖像。 您可以在上閱讀有關變換的更多信息。

  1. import io
  2. import torchvision.transforms as transforms
  3. from PIL import Image
  4. def transform_image(image_bytes):
  5. my_transforms = transforms.Compose([transforms.Resize(255),
  6. transforms.CenterCrop(224),
  7. transforms.ToTensor(),
  8. transforms.Normalize(
  9. [0.485, 0.456, 0.406],
  10. [0.229, 0.224, 0.225])])
  11. image = Image.open(io.BytesIO(image_bytes))
  12. return my_transforms(image).unsqueeze(0)

上面的方法以字節(jié)為單位獲取圖像數(shù)據(jù),應用一系列變換并返回張量。 要測試上述方法,請以字節(jié)模式讀取圖像文件(首先將 <cite>../_static/img/sample_file.jpeg</cite> 替換為計算機上文件的實際路徑),然后查看是否獲得張量 背部:

  1. with open("../_static/img/sample_file.jpeg", 'rb') as f:
  2. image_bytes = f.read()
  3. tensor = transform_image(image_bytes=image_bytes)
  4. print(tensor)

得出:

  1. tensor([[[[ 0.4508, 0.4166, 0.3994, ..., -1.3473, -1.3302, -1.3473],
  2. [ 0.5364, 0.4851, 0.4508, ..., -1.2959, -1.3130, -1.3302],
  3. [ 0.7077, 0.6392, 0.6049, ..., -1.2959, -1.3302, -1.3644],
  4. ...,
  5. [ 1.3755, 1.3927, 1.4098, ..., 1.1700, 1.3584, 1.6667],
  6. [ 1.8893, 1.7694, 1.4440, ..., 1.2899, 1.4783, 1.5468],
  7. [ 1.6324, 1.8379, 1.8379, ..., 1.4783, 1.7352, 1.4612]],
  8. [[ 0.5728, 0.5378, 0.5203, ..., -1.3704, -1.3529, -1.3529],
  9. [ 0.6604, 0.6078, 0.5728, ..., -1.3004, -1.3179, -1.3354],
  10. [ 0.8529, 0.7654, 0.7304, ..., -1.3004, -1.3354, -1.3704],
  11. ...,
  12. [ 1.4657, 1.4657, 1.4832, ..., 1.3256, 1.5357, 1.8508],
  13. [ 2.0084, 1.8683, 1.5182, ..., 1.4657, 1.6583, 1.7283],
  14. [ 1.7458, 1.9384, 1.9209, ..., 1.6583, 1.9209, 1.6408]],
  15. [[ 0.7228, 0.6879, 0.6531, ..., -1.6476, -1.6302, -1.6476],
  16. [ 0.8099, 0.7576, 0.7228, ..., -1.6476, -1.6476, -1.6650],
  17. [ 1.0017, 0.9145, 0.8797, ..., -1.6476, -1.6650, -1.6999],
  18. ...,
  19. [ 1.6291, 1.6291, 1.6465, ..., 1.6291, 1.8208, 2.1346],
  20. [ 2.1868, 2.0300, 1.6814, ..., 1.7685, 1.9428, 2.0125],
  21. [ 1.9254, 2.0997, 2.0823, ..., 1.9428, 2.2043, 1.9080]]]])

預測

現(xiàn)在將使用預訓練的 DenseNet 121 模型來預測圖像類別。 我們將使用torchvision庫中的一個,加載模型并進行推斷。 在此示例中,我們將使用預訓練的模型,但您可以對自己的模型使用相同的方法。

  1. from torchvision import models
  2. ## Make sure to pass `pretrained` as `True` to use the pretrained weights:
  3. model = models.densenet121(pretrained=True)
  4. ## Since we are using our model only for inference, switch to `eval` mode:
  5. model.eval()
  6. def get_prediction(image_bytes):
  7. tensor = transform_image(image_bytes=image_bytes)
  8. outputs = model.forward(tensor)
  9. _, y_hat = outputs.max(1)
  10. return y_hat

張量y_hat將包含預測的類 ID 的索引。 但是,我們需要一個人類可讀的類名。 為此,我們需要一個類 ID 來進行名稱映射。 

  1. import json
  2. imagenet_class_index = json.load(open('../_static/imagenet_class_index.json'))
  3. def get_prediction(image_bytes):
  4. tensor = transform_image(image_bytes=image_bytes)
  5. outputs = model.forward(tensor)
  6. _, y_hat = outputs.max(1)
  7. predicted_idx = str(y_hat.item())
  8. return imagenet_class_index[predicted_idx]

在使用imagenet_class_index字典之前,首先我們將張量值轉(zhuǎn)換為字符串值,因為imagenet_class_index字典中的鍵是字符串。 我們將測試上述方法:

  1. with open("../_static/img/sample_file.jpeg", 'rb') as f:
  2. image_bytes = f.read()
  3. print(get_prediction(image_bytes=image_bytes))

得出:

  1. ['n02124075', 'Egyptian_cat']

您應該得到如下響應:

  1. ['n02124075', 'Egyptian_cat']

數(shù)組中的第一項是 ImageNet 類 ID,第二項是人類可讀的名稱。

Note

您是否注意到model變量不屬于get_prediction方法? 還是為什么模型是全局變量? 就內(nèi)存和計算而言,加載模型可能是一項昂貴的操作。 如果我們在get_prediction方法中加載模型,則每次調(diào)用該方法時都會不必要地加載該模型。 由于我們正在構建一個 Web 服務器,因此每秒可能有成千上萬的請求,因此我們不應該浪費時間為每個推斷重復加載模型。 因此,我們僅將模型加載到內(nèi)存中一次。 在生產(chǎn)系統(tǒng)中,必須有效利用計算以能夠大規(guī)模處理請求,因此通常應在處理請求之前加載模型。

將模型集成到我們的 API 服務器中

在最后一部分中,我們將模型添加到 Flask API 服務器中。 由于我們的 API 服務器應該獲取圖像文件,因此我們將更新predict方法以從請求中讀取文件:

  1. from flask import request
  2. @app.route('/predict', methods=['POST'])
  3. def predict():
  4. if request.method == 'POST':
  5. # we will get the file from the request
  6. file = request.files['file']
  7. # convert that to bytes
  8. img_bytes = file.read()
  9. class_id, class_name = get_prediction(image_bytes=img_bytes)
  10. return jsonify({'class_id': class_id, 'class_name': class_name})

app.py文件現(xiàn)在完成。 以下是完整版本; 將路徑替換為保存文件的路徑,它應運行:

  1. import io
  2. import json
  3. from torchvision import models
  4. import torchvision.transforms as transforms
  5. from PIL import Image
  6. from flask import Flask, jsonify, request
  7. app = Flask(__name__)
  8. imagenet_class_index = json.load(open('<PATH/TO/.json/FILE>/imagenet_class_index.json'))
  9. model = models.densenet121(pretrained=True)
  10. model.eval()
  11. def transform_image(image_bytes):
  12. my_transforms = transforms.Compose([transforms.Resize(255),
  13. transforms.CenterCrop(224),
  14. transforms.ToTensor(),
  15. transforms.Normalize(
  16. [0.485, 0.456, 0.406],
  17. [0.229, 0.224, 0.225])])
  18. image = Image.open(io.BytesIO(image_bytes))
  19. return my_transforms(image).unsqueeze(0)
  20. def get_prediction(image_bytes):
  21. tensor = transform_image(image_bytes=image_bytes)
  22. outputs = model.forward(tensor)
  23. _, y_hat = outputs.max(1)
  24. predicted_idx = str(y_hat.item())
  25. return imagenet_class_index[predicted_idx]
  26. @app.route('/predict', methods=['POST'])
  27. def predict():
  28. if request.method == 'POST':
  29. file = request.files['file']
  30. img_bytes = file.read()
  31. class_id, class_name = get_prediction(image_bytes=img_bytes)
  32. return jsonify({'class_id': class_id, 'class_name': class_name})
  33. if __name__ == '__main__':
  34. app.run()

讓我們測試一下我們的網(wǎng)絡服務器! 跑:

  1. $ FLASK_ENV=development FLASK_APP=app.py flask run

我們可以使用請求庫將 POST 請求發(fā)送到我們的應用:

  1. import requests
  2. resp = requests.post("http://localhost:5000/predict",
  3. files={"file": open('<PATH/TO/.jpg/FILE>/cat.jpg','rb')})

現(xiàn)在打印 <cite>resp.json()</cite>將顯示以下內(nèi)容:

  1. {"class_id": "n02124075", "class_name": "Egyptian_cat"}

下一步

我們編寫的服務器非?,嵥椋赡軣o法完成生產(chǎn)應用程序所需的一切。 因此,您可以采取一些措施來改善它:

  • 端點/predict假定請求中始終會有一個圖像文件。 并非所有請求都適用。 我們的用戶可能發(fā)送帶有其他參數(shù)的圖像,或者根本不發(fā)送任何圖像。
  • 用戶也可以發(fā)送非圖像類型的文件。 由于我們沒有處理錯誤,因此這將破壞我們的服務器。 添加顯式的錯誤處理路徑將引發(fā)異常,這將使我們能夠更好地處理錯誤的輸入
  • 即使模型可以識別大量類別的圖像,也可能無法識別所有圖像。 增強實現(xiàn)以處理模型無法識別圖像中的任何情況的情況。
  • 我們在開發(fā)模式下運行 Flask 服務器,該服務器不適合在生產(chǎn)中進行部署。 您可以查看本教程的,以在生產(chǎn)環(huán)境中部署 Flask 服務器。
  • 您還可以通過創(chuàng)建一個帶有表單的頁面來添加 UI,該表單可以拍攝圖像并顯示預測。 查看類似項目的演示及其源代碼
  • 在本教程中,我們僅展示了如何構建可以一次返回單個圖像預測的服務。 我們可以修改我們的服務,以便能夠一次返回多個圖像的預測。 此外,服務流媒體庫會自動將對服務的請求排隊,并將請求采樣到微型批次中,這些微型批次可輸入模型中。 您可以查看本教程。
  • 最后,我們鼓勵您在頁面頂部查看鏈接到的有關部署 PyTorch 模型的其他教程。

腳本的總運行時間:(0 分鐘 1.971 秒)

Download Python source code: flask_rest_api_tutorial.py Download Jupyter notebook: flask_rest_api_tutorial.ipynb

由獅身人面像畫廊生成的畫廊


以上內(nèi)容是否對您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號
微信公眾號

編程獅公眾號