Использование Mask-RCNN с Flask и TF 2.5.0 rc3

Мне поручили интегрировать чью-то модель в webapp. Они использовали библиотеку под названием Mask-RCNN. Я получаю некоторые ошибки (ниже) при обращении к конечной точке для этой модели. Я видел различные решения для старых версий TensorFlow, но таких решений нет ни в 2.x, ни в 2.5.

Минимальное воспроизведение с Flask ниже. Та же проблема в Django.

# app.py
# run with: flask run
from flask import Flask

app = Flask(__name__)

import tensorflow as tf
import numpy as np
 
from mrcnn import model as modellib
from mrcnn.config import Config
import cv2

class BaseConfig(Config):
    # give the configuration a recognizable name
    NAME = "IMGs"

    # set the number of GPUs to use training along with the number of
    # images per GPU (which may have to be tuned depending on how
    # much memory your GPU has)
    GPU_COUNT = 1
    IMAGES_PER_GPU = 2

    # number of classes (+1 for the background)
    NUM_CLASSES = 1 + 1

    # Most objects possible in an image
    TRAIN_ROIS_PER_IMAGE = 100

class InferenceConfig(BaseConfig):
    GPU_COUNT = 1
    IMAGES_PER_GPU = 1

    # Most objects possible in an image
    DETECTION_MAX_INSTANCES = 200

    DETECTION_MIN_CONFIDENCE = 0.7

def load_model_for_inference(weights_path):
    """Initialize a Mask R-CNN model with our InferenceConfig and the specified weights"""
    inference_config = InferenceConfig()
    model = modellib.MaskRCNN(mode="inference", config=inference_config, model_dir=".")
    model.load_weights(weights_path, by_name=True)
    # model.keras_model._make_predict_function()
    return model

detection_model = load_model_for_inference("mask_rcnn_0427.h5")

@app.route('/detect')
def ic_model_endpoint():
    image = cv2.imread('img.jpg')
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    results = detection_model.detect([image], verbose=1)
    print(results)

@app.route("/")
def hello_world():
    return "<p>Hello, World!</p>"

Использую форк на https://github.com/sabderra/Mask_RCNN, но там нельзя публиковать вопросы, поэтому я решил, что могу спросить здесь. Это форк Mask RCNN, который был обновлен для TF 2, но я думаю, может быть есть еще одно обновление, необходимое для 2.5? Например, должна ли она производить функцию?

Нажатие на конечную точку /detect дает следующее.

ValueError: Tensor Tensor("mrcnn_detection/Reshape_1:0", shape=(1, 200, 6), dtype=float32) is not an element of this graph.

Если мы разкомментируем model.keras_model._make_predict_function(), то получим следующее.

tensorflow.python.framework.errors_impl.InvalidArgumentError: Tensor input_image:0, specified in either feed_devices or fetch_devices was not found in the Graph

Теперь, я видел много решений с использованием сессии, но в этой версии больше нет сессий. Также многие говорят об использовании tf.get_default_graph(), но этого вызова также не существует. Я что-то не понимаю в TensorFlow 2 и в том, как взаимодействуют асинхронные функции

Любое руководство будет оценено по достоинству!

Вернуться на верх