Model.predict hangs in celery/uwsgi
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
from apps.common.utils.error_handling import suppress_callable_to_sentry
from django.conf import settings
from threading import Lock
MODEL_PATH = settings.BASE_DIR / "apps/core/utils/nsfw_detector/nsfw.299x299.h5"
IMAGE_DIM = 299  # images are resized to IMAGE_DIM x IMAGE_DIM before prediction
TOTAL_THRESHOLD = 0.9  # limit for the summed probability of the NSFW classes
INDIVIDUAL_THRESHOLD = 0.7  # limit for any single NSFW class probability

# Labels are paired positionally with the entries of the prediction vector.
CATEGORIES = ("drawings", "hentai", "neutral", "porn", "sexy")
# Subset of CATEGORIES that counts towards an NSFW verdict.
NSFW_CATEGORIES = ("porn", "hentai", "sexy")

# The model is intentionally NOT loaded at import time. `model` stays None
# until the first call to _get_model(), so the Keras model is created inside
# the process that actually runs the prediction.
# NOTE(review): under preforking servers (uWSGI without lazy-apps, Celery's
# prefork pool) module import presumably happens in the parent before fork;
# TensorFlow state created there is a likely cause of predict() hanging in
# the workers -- confirm against the deployment configuration.
model = None
_model_load_lock = Lock()
predict_lock = Lock()


def _get_model():
    """Return the Keras model, loading it from MODEL_PATH on first use."""
    global model
    if model is None:
        # Re-check under the lock so concurrent first callers load only once.
        with _model_load_lock:
            if model is None:
                model = tf.keras.models.load_model(
                    MODEL_PATH,
                    custom_objects={"KerasLayer": hub.KerasLayer},
                    compile=False,
                )
    return model


@suppress_callable_to_sentry(Exception, return_value=False)
def is_nsfw(image):
    """Return True when the model rates ``image`` as NSFW.

    ``image`` must expose ``mode``, ``convert`` and ``resize`` (presumably a
    PIL image -- confirm with callers). It is converted to RGB, resized to
    IMAGE_DIM x IMAGE_DIM, scaled by 1/255 and passed to the model as a batch
    of one.

    The image is flagged when any single NSFW class probability exceeds
    INDIVIDUAL_THRESHOLD, or when the NSFW class probabilities together
    exceed TOTAL_THRESHOLD.

    NOTE(review): the decorator arguments suggest that any exception raised
    here is reported and turned into a ``False`` return value -- confirm in
    ``suppress_callable_to_sentry``.
    """
    # Convert every non-RGB mode (RGBA, L, P, CMYK, ...), not only RGBA, so
    # the array built below always has three channels.
    if image.mode != "RGB":
        image = image.convert("RGB")
    image = image.resize((IMAGE_DIM, IMAGE_DIM))

    # Scale pixel values by 1/255 and add a leading batch axis.
    batch = np.expand_dims(np.array(image) / 255.0, axis=0)

    nsfw_model = _get_model()
    with predict_lock:
        preds = nsfw_model.predict(batch)[0]

    probabilities = {cat: float(pred) for cat, pred in zip(CATEGORIES, preds)}
    nsfw_probs = [probabilities[cat] for cat in NSFW_CATEGORIES]
    return max(nsfw_probs) > INDIVIDUAL_THRESHOLD or sum(nsfw_probs) > TOTAL_THRESHOLD
This works from the Python shell and the Django shell, but it gets stuck at the `predict` call under uWSGI and Celery. Does anyone have an idea why that might be happening?
I added a bunch of breakpoints and the problem is the prediction itself: in the shell it returns in ~100 ms, but under uWSGI and Celery it hangs for 10+ minutes (I didn't wait any longer, as it seems clear it won't return).
I tried it with and without the lock; the result is the same.