Model.predict зависает в celery/uwsgi
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
from apps.common.utils.error_handling import suppress_callable_to_sentry
from django.conf import settings
from threading import Lock
MODEL_PATH = settings.BASE_DIR / "apps/core/utils/nsfw_detector/nsfw.299x299.h5"
model = tf.keras.models.load_model(MODEL_PATH, custom_objects={"KerasLayer": hub.KerasLayer}, compile=False)
IMAGE_DIM = 299
TOTAL_THRESHOLD = 0.9
INDIVIDUAL_THRESHOLD = 0.7
predict_lock = Lock()
@suppress_callable_to_sentry(Exception, return_value=False)
def is_nsfw(image):
if image.mode == "RGBA":
image = image.convert("RGB")
image = image.resize((IMAGE_DIM, IMAGE_DIM))
image = np.array(image) / 255.0
image = np.expand_dims(image, axis=0)
with predict_lock:
preds = model.predict(image)[0]
categories = ["drawings", "hentai", "neutral", "porn", "sexy"]
probabilities = {cat: float(pred) for cat, pred in zip(categories, preds)}
individual_nsfw_prob = max(probabilities["porn"], probabilities["hentai"], probabilities["sexy"])
total_nsfw_prob = probabilities["porn"] + probabilities["hentai"] + probabilities["sexy"]
return (individual_nsfw_prob > INDIVIDUAL_THRESHOLD) or (total_nsfw_prob > TOTAL_THRESHOLD)
Это работает из python shell и django shell, но застревает в predict part в uwsgi и в celery, у кого-нибудь есть какие-либо идеи, почему это может происходить?
Я поставил кучу точек останова, и проблема в самом предсказании, в shell оно возвращается через ~100 мс, но зависает в uwsgi и celery более чем на 10 минут (дольше не пробовал, так как думаю, очевидно, что оно не вернется)
Пробовал с блокировкой и без нее, результат тот же