Model.predict зависает в celery/uwsgi

import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
from apps.common.utils.error_handling import suppress_callable_to_sentry
from django.conf import settings
from threading import Lock

MODEL_PATH = settings.BASE_DIR / "apps/core/utils/nsfw_detector/nsfw.299x299.h5"
model = tf.keras.models.load_model(MODEL_PATH, custom_objects={"KerasLayer": hub.KerasLayer}, compile=False)

IMAGE_DIM = 299
TOTAL_THRESHOLD = 0.9
INDIVIDUAL_THRESHOLD = 0.7


predict_lock = Lock()

@suppress_callable_to_sentry(Exception, return_value=False)
def is_nsfw(image):
    if image.mode == "RGBA":
        image = image.convert("RGB")
    image = image.resize((IMAGE_DIM, IMAGE_DIM))
    image = np.array(image) / 255.0
    image = np.expand_dims(image, axis=0)
    with predict_lock:
        preds = model.predict(image)[0]
    categories = ["drawings", "hentai", "neutral", "porn", "sexy"]
    probabilities = {cat: float(pred) for cat, pred in zip(categories, preds)}
    individual_nsfw_prob = max(probabilities["porn"], probabilities["hentai"], probabilities["sexy"])
    total_nsfw_prob = probabilities["porn"] + probabilities["hentai"] + probabilities["sexy"]
    return (individual_nsfw_prob > INDIVIDUAL_THRESHOLD) or (total_nsfw_prob > TOTAL_THRESHOLD)

Это работает из python shell и django shell, но застревает в predict part в uwsgi и в celery, у кого-нибудь есть какие-либо идеи, почему это может происходить?

Я поставил кучу точек останова, и проблема в самом предсказании, в shell оно возвращается через ~100 мс, но зависает в uwsgi и celery более чем на 10 минут (дольше не пробовал, так как думаю, очевидно, что оно не вернется)

Пробовал с блокировкой и без нее, результат тот же

Вернуться на верх