Keras 使用 ResNet 模型進(jìn)行實(shí)時(shí)預(yù)測

2021-11-03 14:19 更新

ResNet是一個(gè)預(yù)訓(xùn)練模型。它使用 ImageNet 進(jìn)行訓(xùn)練。在 ImageNet 上預(yù)訓(xùn)練的 ResNet 模型權(quán)重。它具有以下語法:

  1. keras.applications.resnet.ResNet50 (
  2. include_top = True,
  3. weights = 'imagenet',
  4. input_tensor = None,
  5. input_shape = None,
  6. pooling = None,
  7. classes = 1000
  8. )
  • include_top 指的是網(wǎng)絡(luò)頂部的全連接層。
  • weights 指的是 ImageNet 上的預(yù)訓(xùn)練。
  • input_tensor 指用作模型的圖像輸入的可選的 Keras 張量。
  • input_shape 指可選的形狀元組。此模型的默認(rèn)輸入大小為 224x224。
  • clasees 指用于對(duì)圖像進(jìn)行分類的可選數(shù)量的類。

讓我們通過寫一個(gè)簡單的例子來理解模型:

第 1 步:導(dǎo)入模塊

加載如下指定的必要模塊:

  1. >>> import PIL
  2. >>> from keras.preprocessing.image import load_img
  3. >>> from keras.preprocessing.image import img_to_array
  4. >>> from keras.applications.imagenet_utils import decode_predictions
  5. >>> import matplotlib.pyplot as plt
  6. >>> import numpy as np
  7. >>> from keras.applications.resnet50 import ResNet50
  8. >>> from keras.applications import resnet50

第 2 步:選擇一個(gè)輸入

選擇一個(gè)輸入圖像,Lotus,如下所示:

  1. >>> filename = 'banana.jpg'
  2. >>> ## load an image in PIL format
  3. >>> original = load_img(filename, target_size = (224, 224))
  4. >>> print('PIL image size',original.size)
  5. PIL image size (224, 224)
  6. >>> plt.imshow(original)
  7. <matplotlib.image.AxesImage object at 0x1304756d8>
  8. >>> plt.show()

在這里,我們加載了一個(gè)圖像(banana.jpg)并顯示了它。

第 3 步:將圖像轉(zhuǎn)換為 NumPy 數(shù)組

將輸入的 Banana 轉(zhuǎn)換為 NumPy 數(shù)組,以便將其傳遞到模型中以進(jìn)行預(yù)測。

  1. >>> #convert the PIL image to a numpy array
  2. >>> numpy_image = img_to_array(original)
  3. >>> plt.imshow(np.uint8(numpy_image))
  4. <matplotlib.image.AxesImage object at 0x130475ac8>
  5. >>> print('numpy array size',numpy_image.shape)
  6. numpy array size (224, 224, 3)
  7. >>> # Convert the image / images into batch format
  8. >>> image_batch = np.expand_dims(numpy_image, axis = 0)
  9. >>> print('image batch size', image_batch.shape)
  10. image batch size (1, 224, 224, 3)
  11. >>>

第 4 步:模型預(yù)測

將輸入輸入模型以獲得預(yù)測

  1. >>> prepare the image for the resnet50 model >>>
  2. >>> processed_image = resnet50.preprocess_input(image_batch.copy())
  3. >>> # create resnet model
  4. >>>resnet_model = resnet50.ResNet50(weights = 'imagenet')
  5. >>> Downloavding data from https://github.com/fchollet/deep-learning-models/releas
  6. es/download/v0.2/resnet50_weights_tf_dim_ordering_tf_kernels.h5
  7. 102858752/102853048 [==============================] - 33s 0us/step
  8. >>> # get the predicted probabilities for each class
  9. >>> predictions = resnet_model.predict(processed_image)
  10. >>> # convert the probabilities to class labels
  11. >>> label = decode_predictions(predictions)
  12. Downloading data from https://storage.googleapis.com/download.tensorflow.org/
  13. data/imagenet_class_index.json
  14. 40960/35363 [==================================] - 0s 0us/step
  15. >>> print(label)

輸出

  1. [
  2. [
  3. ('n07753592', 'banana', 0.99229723),
  4. ('n03532672', 'hook', 0.0014551596),
  5. ('n03970156', 'plunger', 0.0010738898),
  6. ('n07753113', 'fig', 0.0009359837) ,
  7. ('n03109150', 'corkscrew', 0.00028538404)
  8. ]
  9. ]

模型就可以正確地將圖像預(yù)測為 banana。

以上內(nèi)容是否對(duì)您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號(hào)
微信公眾號(hào)

編程獅公眾號(hào)