On every api request list in add_where method get appended it is not resetting

I am trying to create custom manager in django rest framework that interacts with athena. Although I am facing strange issue that everytime I hit the endpoint with filters it gets appended in list where_conditions=[] on each request. It should not happen that actually makes the query that converts into SQL invalid. i want to reset the where_conditions list on each request.

filter_class/base.py

from apps.analytics.models.filter.fields.filter import FilterField
from apps.analytics.models.filter.fields.ordering import OrderingField
from apps.analytics.models.filter.fields.search import SearchField
from apps.analytics.models.manager.query.utils import ModelQueryUtils


class BaseFilterClass:
    _ordering_filter_query_key = "ordering"
    _search_filter_query_key = "search"

    def __init__(self, request):
        self._request = request

    @property
    def filter_params(self):
        if getattr(self.Meta, "filter_fields", None):
            req_query_params = {**self._request.query_params}
            for i in req_query_params:
                if isinstance(req_query_params[i], list):
                    req_query_params[i] = req_query_params[i][0]
            if req_query_params and req_query_params.get("search"):
                del req_query_params["search"]
            req_query_params_key = set(req_query_params.keys())
            valid_filter_params = self.Meta.filter_fields.declared_filter_params.intersection(req_query_params_key)
            params = {key: self.Meta.filter_fields.get_processed_value(key.split('__')[-1],
                                                                       req_query_params[key]) if "__" in key else
            req_query_params[key] for key in valid_filter_params}
            return ModelQueryUtils.FilterInput(params, ModelQueryUtils.FilterInput.JoinOperator.AND)

    @property
    def ordering_params(self):
        if getattr(self.Meta, "ordering_fields", None):
            req_query_params_ordering_params = self._request.query_params and self._request.query_params.get(
                self._ordering_filter_query_key)
            if req_query_params_ordering_params:
                declared_ordering_params = self.Meta.ordering_fields.declared_ordering_params
                return ModelQueryUtils.OrderInput(
                    [str(param) for param in req_query_params_ordering_params.split(",") if
                     ((param.startswith("-") and param[1:]) or param) in declared_ordering_params])

    @property
    def search_params(self):
        if getattr(self.Meta, "search_fields", None):
            req_query_params_search_param = self._request.query_params and self._request.query_params.get(
                self._search_filter_query_key)
            params = {}
            if req_query_params_search_param:
                for key in self.Meta.search_fields.declared_search_params:
                    params[key] = req_query_params_search_param
            return ModelQueryUtils.SearchInput(params, ModelQueryUtils.FilterInput.JoinOperator.OR)

    class Meta:
        filter_fields = FilterField([])
        search_fields = SearchField([])
        ordering_fields = OrderingField([])


manager/query/base.py

class BaseModelQuery:
    __slots__ = ['_model']
    # _where_conditions = None
    _order_conditions = None
    _offset: int = None
    _limit: int = None
    _full_query_mtd_name: str = None

    def __init__(self, model):
        self._model = model
        self._where_conditions = []
        self._order_conditions = []
        self._offset = None
        self._limit = None
        self._full_query_mtd_name = None

    @property
    def count_sql(self):
        if self._full_query_mtd_name:
            return getattr(self, self._full_query_mtd_name)()["count"]
        return f"""
            SELECT
                Count(*)
            FROM
                {self._model.table_name}
            {self._where_sql}
        """

    @property
    def full_sql(self):
        if self._full_query_mtd_name:
            return getattr(self, self._full_query_mtd_name)()["full"]
        return f"""
            {self._base_sql}
            {self._where_sql}
            {self._order_by_sql}
            {self._offset_sql}
            {self._limit_sql}
        """

    @property
    def _base_sql(self):
        return f"""
            SELECT
                {','.join(field_name for field_name in self._model.fields.keys())}
            FROM
                {self._model.table_name}
        """

    @property
    def _where_sql(self):
        if self._where_conditions:
            statements = [item.get_sql_statement(self._model) for item in self._where_conditions]
            print(statements)
            statements = [f"({item})" for item in statements if item]
            if statements:
                return f"""
                WHERE
                    {" AND ".join(statements)}
                """
        return ""

    @property
    def _order_by_sql(self):
        if self._order_conditions:
            statements = [item.get_sql_statement() for item in self._order_conditions]
            statements = [item for item in statements if item]
            if statements:
                return f"""
            ORDER BY
                {", ".join(statements)}
            """
        return ""

    @property
    def _offset_sql(self):
        if self._offset is not None:
            return f"""
            OFFSET {self._offset}
            """
        return ""

    @property
    def _limit_sql(self):
        if self._limit is not None:
            return f"""
            LIMIT {self._limit}
            """
        return ""

    def add_where(self, condition):
        self._where_conditions.append(condition)

    def add_order(self, condition):
        self._order_conditions.append(condition)

    def set_offset(self, offset: int):
        self._offset = offset

    def set_limit(self, limit: int):
        self._limit = limit

    def set_full_query_mtd_name(self, name: str):
        self._full_query_mtd_name = name

manager/query/utils.py

from apps.analytics.models.fields import ModelField
from apps.analytics.models.filter.lookups import FilterFieldLookup


class ModelQueryUtils:
    class FilterInput:
        _filter_lookup_operator_mapping = {
            FilterFieldLookup.Definition.NOT_EQUAL: "<>",
            FilterFieldLookup.Definition.GREATER_THAN_EQUAL_TO: ">=",
            FilterFieldLookup.Definition.LESS_THAN_EQUAL_TO: "<=",
            FilterFieldLookup.Definition.GREATER_THAN: ">",
            FilterFieldLookup.Definition.LESS_THAN: "<",
            FilterFieldLookup.Definition.IS_IN: "IN",
            FilterFieldLookup.Definition.CONTAINS: "LIKE",
            FilterFieldLookup.Definition.ICONTAINS: "LIKE"
        }

        class JoinOperator:
            OR = "OR"
            AND = "AND"

        def __init__(self, filter_params: dict, operator: str):
            self.filter_params = filter_params
            self.operator = operator

        def get_sql_statement(self, model):
            list_conditions = []
            for key, value in self.filter_params.items():
                field, lookup = key.split("__") if "__" in key else (key, None)
                model_field: ModelField = model.fields.get(field)
                if model_field:
                    value = model_field.get_cast_statement_for_value(value)
                    list_conditions.append(
                        f"{field} {self._filter_lookup_operator_mapping.get(lookup) or '='} {value}")
            return list_conditions and f" {self.operator} ".join(list_conditions)

    class SearchInput(FilterInput):

        def get_sql_statement(self, model):
            list_conditions = []
            for key, value in self.filter_params.items():
                field, lookup = key.split("__") if "__" in key else (key, None)
                model_field: ModelField = model.fields.get(field)
                if model_field:
                    value = model_field.get_cast_statement_for_value(value)
                    value = value.strip("'")
                    list_conditions.append(
                        f"LOWER({field}) {self._filter_lookup_operator_mapping.get(lookup) or '='} LOWER('%{value}%')")
            return list_conditions and f" {self.operator} ".join(list_conditions)

    class OrderInput:
        def __init__(self, order_params: list[str]):
            self.order_params = order_params

        def get_sql_statement(self):
            return ", ".join("{}{}".format(param.lstrip("-"), " DESC" if param.startswith("-") else "") for param in
                             self.order_params)

manager/queryset/base.py

from apps.analytics.models.manager.query.base import BaseModelQuery
from apps.analytics.models.manager.query.utils import ModelQueryUtils
from apps.utils.custom_exeption.custom_exceptions import CustomException
from apps.utils.helpers.externals.aws.athena.client import AWSAthenaClient

class BaseModelQueryset:
    __slots__ = ['_query', '_annotations']

    def __init__(self, query: BaseModelQuery):
        self._query = query
        self._annotations = {}
    def __iter__(self):
        query_client = AWSAthenaClient()

        print(" ---- FULL QUERY -----------")
        print(self._query.full_sql)
        print("--- working ---")
        response, _ = query_client.execute_query(self._query.full_sql)

        if _:
            raise CustomException(
                data={
                    "message": _,
                    "errors": ["analytic_query_error"]
                }
            )
        for item in response:
            item.update(self._annotations)
            yield item

    
    def annotate(self, **annotations):
        """Custom method to handle annotations like Sum or other aggregations."""
        for alias, annotation in annotations.items():
            if isinstance(annotation, str) and annotation.lower() == 'sum':
                # Construct and execute a sum query
                query_client = AWSAthenaClient()
                sum_query = f"SELECT SUM(event_count) AS {alias} FROM {self._query.table_name}"
                response, error = query_client.execute_query(sum_query)

                if error:
                    raise CustomException(
                        data={"message": error, "errors": ["annotation_query_error"]}
                    )

                # Store the result in _annotations
                self._annotations[alias] = response[0].get(alias, 0)
        
        return self
    def count(self):
        query_client = AWSAthenaClient()
        response, _ = query_client.execute_query(self._query.count_sql)
        if _ or not (response):
            raise CustomException(
                data={
                    "message": _,
                    "errors": ["analytic_query_error"]
                }
            )
        return int(response[0]['_col0'])

    def filter(self, filter_input=None, **kwargs):
        self._query._where_conditions = []
        print(" -- Filter ---")
        print(filter_input)
        if filter_input:
            self._query.add_where(filter_input)
        else:
            print(" --- I am here ----")
            self._query.add_where(ModelQueryUtils.FilterInput(kwargs, ModelQueryUtils.FilterInput.JoinOperator.AND))
        return self

    def order_by(self, *args, order_input=None):
        if order_input:
            self._query.add_order(order_input)
        else:
            self._query.add_order(ModelQueryUtils.OrderInput(args))
        return self

    def __getitem__(self, key):
        if isinstance(key, slice):
            if key.start:
                self._query.set_offset(key.start)
            if key.stop:
                self._query.set_limit(key.stop - key.start)
        else:
            self._query.set_offset(key)
            self._query.set_limit(1)
        return self

manager/init.py


class LineIntrusionAnalyticsMananger(BaseModelManager):
    def __init__(self):
        super().__init__(queryset_class=queryset.LineIntrusionAnalyticsQueryset)

Methods like .add_where(…) don't make much sense on a QuerySet. Django's QuerySets are essentially immutable. That means that if you use MyModel.objects.filter(some_filter1).filter(some_filter2), each call to the .filter(…) produces a new QuerySet that is a modified copy of the previous one.

This might look like a detail, and an inefficient one, since we each time clone the QuerySet, but it is an important one. Indeed, in Django if you make for example a ListView:

class MyListView(ListView):
    queryset = MyModel.objects.all()

then you set a QuerySet to the queryset attribute, Django each time calls .all() to clone the QuerySet, such that it does not uses the cache of the "parent" QuerySet, and thus forces reevaluation. For the same reason if you would .filter(…) the .queryset, it has no implications on the original one.

You thus should not define methods that change state: each method should create a new QuerySet. Failing to do so might eventually result in a "global state" where you thus alter the "parent" queryset.

Вернуться на верх