import itertools
import math
from django.core.exceptions import EmptyResultSet
from django.db.models.expressions import Case, Expression, Func, Value, When
from django.db.models.fields import (
BooleanField,
CharField,
DateTimeField,
Field,
IntegerField,
UUIDField,
)
from django.db.models.query_utils import RegisterLookupMixin
from django.utils.datastructures import OrderedSet
from django.utils.functional import cached_property
from django.utils.hashable import make_hashable
class Lookup(Expression):
    """
    Base class for lookups: boolean expressions comparing ``lhs`` with
    ``rhs``, such as the ``exact`` in ``boolfield__exact=True``.

    On construction the right-hand side is prepared with
    ``get_prep_lookup()`` and the left-hand side with ``get_prep_lhs()``.
    Any bilateral transforms of ``lhs`` are recorded so they can also be
    applied to ``rhs`` when the lookup is compiled.
    """

    lookup_name = None
    # When False, get_prep_lookup() returns the rhs without preparing it.
    prepare_rhs = True
    can_use_none_as_rhs = False

    def __init__(self, lhs, rhs):
        self.lhs, self.rhs = lhs, rhs
        # The rhs is prepared first, so get_prep_lookup() sees the raw lhs.
        self.rhs = self.get_prep_lookup()
        self.lhs = self.get_prep_lhs()
        if hasattr(self.lhs, "get_bilateral_transforms"):
            bilateral_transforms = self.lhs.get_bilateral_transforms()
        else:
            bilateral_transforms = []
        if bilateral_transforms:
            # Warn the user as soon as possible if they are trying to apply
            # a bilateral transformation on a nested QuerySet: that won't work.
            from django.db.models.sql.query import Query  # avoid circular import

            if isinstance(rhs, Query):
                raise NotImplementedError(
                    "Bilateral transformations on nested querysets are not implemented."
                )
        self.bilateral_transforms = bilateral_transforms

    def apply_bilateral_transforms(self, value):
        """Wrap ``value`` in each recorded bilateral transform, in order."""
        for transform in self.bilateral_transforms:
            value = transform(value)
        return value

    def __repr__(self):
        return f"{self.__class__.__name__}({self.lhs!r}, {self.rhs!r})"

    def batch_process_rhs(self, compiler, connection, rhs=None):
        """
        Compile an iterable rhs (``self.rhs`` unless ``rhs`` is given).

        Return ``(sqls, params)``: one SQL fragment per value and a flat list
        of params. With bilateral transforms, each value is wrapped in a
        ``Value``, transformed, resolved and compiled on its own. Otherwise
        the iterable goes through ``get_db_prep_lookup()`` and each resulting
        param gets a ``%s`` placeholder.
        """
        if rhs is None:
            rhs = self.rhs
        if self.bilateral_transforms:
            sqls, sqls_params = [], []
            for p in rhs:
                value = Value(p, output_field=self.lhs.output_field)
                value = self.apply_bilateral_transforms(value)
                value = value.resolve_expression(compiler.query)
                sql, sql_params = compiler.compile(value)
                sqls.append(sql)
                sqls_params.extend(sql_params)
        else:
            _, params = self.get_db_prep_lookup(rhs, connection)
            sqls, sqls_params = ["%s"] * len(params), params
        return sqls, sqls_params

    def get_source_expressions(self):
        # A direct (non-expression) rhs is not a source expression.
        if self.rhs_is_direct_value():
            return [self.lhs]
        return [self.lhs, self.rhs]

    def set_source_expressions(self, new_exprs):
        # Mirrors get_source_expressions(): one item means only lhs.
        if len(new_exprs) == 1:
            self.lhs = new_exprs[0]
        else:
            self.lhs, self.rhs = new_exprs

    def get_prep_lookup(self):
        """
        Return the prepared rhs.

        Expressions, and any rhs when ``prepare_rhs`` is False, are returned
        unchanged. Otherwise the lhs field's ``get_prep_value()`` is applied
        when the lhs has an ``output_field`` that provides it. A direct value
        compared against an lhs without ``output_field`` is wrapped in
        ``Value``.
        """
        if not self.prepare_rhs or hasattr(self.rhs, "resolve_expression"):
            return self.rhs
        if hasattr(self.lhs, "output_field"):
            if hasattr(self.lhs.output_field, "get_prep_value"):
                return self.lhs.output_field.get_prep_value(self.rhs)
        elif self.rhs_is_direct_value():
            return Value(self.rhs)
        return self.rhs

    def get_prep_lhs(self):
        """Return the lhs as an expression, wrapping plain values in ``Value``."""
        if hasattr(self.lhs, "resolve_expression"):
            return self.lhs
        return Value(self.lhs)

    def get_db_prep_lookup(self, value, connection):
        """Return ``("%s", [value])``; subclasses may convert ``value``."""
        return ("%s", [value])

    def process_lhs(self, compiler, connection, lhs=None):
        """
        Compile the left-hand side and return ``(sql, params)``.

        ``lhs`` defaults to ``self.lhs``. A Lookup used as the lhs is wrapped
        in parentheses.
        """
        # Explicit None check: an expression that happens to be falsy must
        # not be silently replaced by self.lhs.
        if lhs is None:
            lhs = self.lhs
        if hasattr(lhs, "resolve_expression"):
            lhs = lhs.resolve_expression(compiler.query)
        sql, params = compiler.compile(lhs)
        if isinstance(lhs, Lookup):
            # Wrapped in parentheses to respect operator precedence.
            sql = f"({sql})"
        return sql, params

    def process_rhs(self, compiler, connection):
        """
        Compile the right-hand side and return ``(sql, params)``.

        With bilateral transforms, a direct value is first wrapped in
        ``Value`` and the transforms are applied. Anything that has
        ``as_sql()`` is then compiled and parenthesized. Plain values go
        through ``get_db_prep_lookup()``.
        """
        value = self.rhs
        if self.bilateral_transforms:
            if self.rhs_is_direct_value():
                # Do not call get_db_prep_lookup here as the value will be
                # transformed before being used for lookup
                value = Value(value, output_field=self.lhs.output_field)
            value = self.apply_bilateral_transforms(value)
            value = value.resolve_expression(compiler.query)
        if hasattr(value, "as_sql"):
            sql, params = compiler.compile(value)
            # Ensure expression is wrapped in parentheses to respect operator
            # precedence but avoid double wrapping as it can be misinterpreted
            # on some backends (e.g. subqueries on SQLite).
            if sql and sql[0] != "(":
                sql = "(%s)" % sql
            return sql, params
        else:
            return self.get_db_prep_lookup(value, connection)

    def rhs_is_direct_value(self):
        """True when the rhs is a plain value rather than a compilable expression."""
        return not hasattr(self.rhs, "as_sql")

    def get_group_by_cols(self):
        cols = []
        for source in self.get_source_expressions():
            cols.extend(source.get_group_by_cols())
        return cols

    def as_oracle(self, compiler, connection):
        # Oracle doesn't allow EXISTS() and filters to be compared to another
        # expression unless they're wrapped in a CASE WHEN.
        wrapped = False
        exprs = []
        for expr in (self.lhs, self.rhs):
            if connection.ops.conditional_expression_supported_in_where_clause(expr):
                expr = Case(When(expr, then=True), default=False)
                wrapped = True
            exprs.append(expr)
        # Build a new lookup of the same type only when something was wrapped.
        lookup = type(self)(*exprs) if wrapped else self
        return lookup.as_sql(compiler, connection)

    @cached_property
    def output_field(self):
        return BooleanField()

    @property
    def identity(self):
        # Used for equality and hashing: class plus both operands.
        return self.__class__, self.lhs, self.rhs

    def __eq__(self, other):
        if not isinstance(other, Lookup):
            return NotImplemented
        return self.identity == other.identity

    def __hash__(self):
        return hash(make_hashable(self.identity))

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """Return a copy with lhs (and an expression rhs) resolved against ``query``."""
        c = self.copy()
        c.is_summary = summarize
        c.lhs = self.lhs.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        if hasattr(self.rhs, "resolve_expression"):
            c.rhs = self.rhs.resolve_expression(
                query, allow_joins, reuse, summarize, for_save
            )
        return c

    def select_format(self, compiler, sql, params):
        # Wrap filters with a CASE WHEN expression if a database backend
        # (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP
        # BY list.
        if not compiler.connection.features.supports_boolean_expr_in_select_clause:
            sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END"
        return sql, params
class BuiltinLookup(Lookup):
    """
    Lookup rendered as ``<lhs> <operator>``, using the backend's cast helpers
    for the lhs and ``connection.operators[lookup_name]`` for the operator.
    """

    def process_lhs(self, compiler, connection, lhs=None):
        sql, lhs_params = super().process_lhs(compiler, connection, lhs)
        output_field = self.lhs.output_field
        internal_type = output_field.get_internal_type()
        db_type = output_field.db_type(connection=connection)
        # The field cast is applied first; the lookup-specific cast wraps it.
        field_cast = connection.ops.field_cast_sql(db_type, internal_type)
        lookup_cast = connection.ops.lookup_cast(self.lookup_name, internal_type)
        return lookup_cast % (field_cast % sql), list(lhs_params)

    def as_sql(self, compiler, connection):
        lhs_sql, params = self.process_lhs(compiler, connection)
        rhs_sql, rhs_params = self.process_rhs(compiler, connection)
        params.extend(rhs_params)
        return "%s %s" % (lhs_sql, self.get_rhs_op(connection, rhs_sql)), params

    def get_rhs_op(self, connection, rhs):
        operator_template = connection.operators[self.lookup_name]
        return operator_template % rhs
class FieldGetDbPrepValueMixin:
    """
    Mixin for lookups whose rhs must pass through
    ``Field.get_db_prep_value()`` (with ``prepared=True``).

    When ``get_db_prep_lookup_value_is_iterable`` is True, each item of the
    value is converted separately.
    """

    get_db_prep_lookup_value_is_iterable = False

    def get_db_prep_lookup(self, value, connection):
        output_field = self.lhs.output_field
        # Relational fields expose the field they point to as target_field;
        # its get_db_prep_value() takes precedence when present.
        target = getattr(output_field, "target_field", None)
        convert = getattr(target, "get_db_prep_value", None)
        if not convert:
            convert = output_field.get_db_prep_value
        if self.get_db_prep_lookup_value_is_iterable:
            converted = [convert(item, connection, prepared=True) for item in value]
        else:
            converted = [convert(value, connection, prepared=True)]
        return "%s", converted
class FieldGetDbPrepValueIterableMixin(FieldGetDbPrepValueMixin):
    """
    Variant of FieldGetDbPrepValueMixin for lookups whose rhs is an iterable:
    every item is prepared and converted individually, and items that are
    expressions are compiled in place.
    """

    get_db_prep_lookup_value_is_iterable = True

    def get_prep_lookup(self):
        if hasattr(self.rhs, "resolve_expression"):
            return self.rhs

        def prepare(item):
            # An expression will be handled by the database but can coexist
            # alongside real values, so it is kept as-is.
            if hasattr(item, "resolve_expression"):
                return item
            if self.prepare_rhs and hasattr(self.lhs.output_field, "get_prep_value"):
                return self.lhs.output_field.get_prep_value(item)
            return item

        return [prepare(item) for item in self.rhs]

    def process_rhs(self, compiler, connection):
        if not self.rhs_is_direct_value():
            return super().process_rhs(compiler, connection)
        # A direct rhs is an iterable of values; batch_process_rhs()
        # prepares/transforms each of them.
        return self.batch_process_rhs(compiler, connection)

    def resolve_expression_parameter(self, compiler, connection, sql, param):
        original_param = param
        if hasattr(param, "resolve_expression"):
            param = param.resolve_expression(compiler.query)
        if hasattr(param, "as_sql"):
            compiled_sql, compiled_params = compiler.compile(param)
            return compiled_sql, compiled_params
        # Not compilable: keep the placeholder and the value as given.
        return sql, [original_param]

    def batch_process_rhs(self, compiler, connection, rhs=None):
        placeholders, raw_params = super().batch_process_rhs(compiler, connection, rhs)
        # Each (placeholder, param) pair refers to the same argument; params
        # that are expressions are replaced by their compiled sql/params.
        resolved = [
            self.resolve_expression_parameter(compiler, connection, placeholder, param)
            for placeholder, param in zip(placeholders, raw_params)
        ]
        sql, grouped_params = zip(*resolved)
        return sql, tuple(itertools.chain.from_iterable(grouped_params))
class PostgresOperatorLookup(Lookup):
    """Lookup defined by operators on PostgreSQL."""

    # SQL operator placed between the compiled lhs and rhs; set by subclasses.
    postgres_operator = None

    def as_postgresql(self, compiler, connection):
        lhs_sql, lhs_params = self.process_lhs(compiler, connection)
        rhs_sql, rhs_params = self.process_rhs(compiler, connection)
        sql = f"{lhs_sql} {self.postgres_operator} {rhs_sql}"
        return sql, (*lhs_params, *rhs_params)
@Field.register_lookup
class Exact(FieldGetDbPrepValueMixin, BuiltinLookup):
    """Equality lookup (``exact``); also accepts a single-row subquery."""

    lookup_name = "exact"

    def get_prep_lookup(self):
        from django.db.models.sql.query import Query  # avoid circular import

        if not isinstance(self.rhs, Query):
            return super().get_prep_lookup()
        if not self.rhs.has_limit_one():
            raise ValueError(
                "The QuerySet value for an exact lookup must be limited to "
                "one result using slicing."
            )
        # A subquery without explicit select fields compares against pk.
        if not self.rhs.has_select_fields:
            self.rhs.clear_select_clause()
            self.rhs.add_fields(["pk"])
        return super().get_prep_lookup()

    def as_sql(self, compiler, connection):
        # For a conditional lhs compared with a literal bool, emit
        # "WHERE boolean_field" / "WHERE NOT boolean_field" instead of
        # "WHERE boolean_field = True" when the backend allows it.
        use_bare_condition = (
            isinstance(self.rhs, bool)
            and getattr(self.lhs, "conditional", False)
            and connection.ops.conditional_expression_supported_in_where_clause(
                self.lhs
            )
        )
        if not use_bare_condition:
            return super().as_sql(compiler, connection)
        lhs_sql, params = self.process_lhs(compiler, connection)
        if self.rhs:
            return "%s" % lhs_sql, params
        return "NOT %s" % lhs_sql, params
@Field.register_lookup
class IExact(BuiltinLookup):
    """Case-insensitive equality lookup (``iexact``)."""

    lookup_name = "iexact"
    prepare_rhs = False

    def process_rhs(self, qn, connection):
        rhs_sql, rhs_params = super().process_rhs(qn, connection)
        if rhs_params:
            # Backend hook applied to the first parameter only.
            rhs_params[0] = connection.ops.prep_for_iexact_query(rhs_params[0])
        return rhs_sql, rhs_params
@Field.register_lookup
class GreaterThan(FieldGetDbPrepValueMixin, BuiltinLookup):
    """Greater-than comparison (``gt``)."""

    lookup_name = "gt"


@Field.register_lookup
class GreaterThanOrEqual(FieldGetDbPrepValueMixin, BuiltinLookup):
    """Greater-than-or-equal comparison (``gte``)."""

    lookup_name = "gte"


@Field.register_lookup
class LessThan(FieldGetDbPrepValueMixin, BuiltinLookup):
    """Less-than comparison (``lt``)."""

    lookup_name = "lt"


@Field.register_lookup
class LessThanOrEqual(FieldGetDbPrepValueMixin, BuiltinLookup):
    """Less-than-or-equal comparison (``lte``)."""

    lookup_name = "lte"
class IntegerFieldFloatRounding:
    """
    Allow floats to work as query values for IntegerField. Without this, the
    decimal portion of the float would always be discarded.

    A float rhs is rounded up with ``math.ceil()`` before the normal
    preparation. Rounding up is correct for ``>=`` and ``<``: on integers,
    ``x >= 1.5`` is ``x >= 2`` and ``x < 1.5`` is ``x < 2``.
    """

    def get_prep_lookup(self):
        value = self.rhs
        if isinstance(value, float):
            self.rhs = math.ceil(value)
        return super().get_prep_lookup()
@IntegerField.register_lookup
class IntegerGreaterThanOrEqual(IntegerFieldFloatRounding, GreaterThanOrEqual):
    """``gte`` for IntegerField; a float rhs is rounded up first."""


@IntegerField.register_lookup
class IntegerLessThan(IntegerFieldFloatRounding, LessThan):
    """``lt`` for IntegerField; a float rhs is rounded up first."""
@Field.register_lookup
class In(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
    """
    ``IN`` lookup against an iterable of values or a subquery.

    For a direct rhs, None values are dropped (and hashable duplicates
    collapsed) in both the regular path and the split path used when the
    backend limits the IN list size. This keeps results independent of the
    number of values.
    """

    lookup_name = "in"

    def get_prep_lookup(self):
        from django.db.models.sql.query import Query  # avoid circular import

        if isinstance(self.rhs, Query):
            self.rhs.clear_ordering(clear_default=True)
            if not self.rhs.has_select_fields:
                self.rhs.clear_select_clause()
                self.rhs.add_fields(["pk"])
        return super().get_prep_lookup()

    def _get_direct_rhs_values(self):
        """
        Return the direct rhs values without None entries.

        Remove None from the values as NULL is never equal to anything.
        Hashable values are also de-duplicated in first-seen order. If any
        value is unhashable, a plain list is returned instead.
        """
        try:
            rhs = OrderedSet(self.rhs)
            rhs.discard(None)
        except TypeError:  # Unhashable items in self.rhs
            rhs = [r for r in self.rhs if r is not None]
        return rhs

    def process_rhs(self, compiler, connection):
        db_rhs = getattr(self.rhs, "_db", None)
        if db_rhs is not None and db_rhs != connection.alias:
            raise ValueError(
                "Subqueries aren't allowed across different databases. Force "
                "the inner query to be evaluated using `list(inner_query)`."
            )
        if self.rhs_is_direct_value():
            rhs = self._get_direct_rhs_values()
            if not rhs:
                raise EmptyResultSet
            # rhs should be an iterable; use batch_process_rhs() to
            # prepare/transform those values.
            sqls, sqls_params = self.batch_process_rhs(compiler, connection, rhs)
            placeholder = "(" + ", ".join(sqls) + ")"
            return (placeholder, sqls_params)
        return super().process_rhs(compiler, connection)

    def get_rhs_op(self, connection, rhs):
        return "IN %s" % rhs

    def as_sql(self, compiler, connection):
        max_in_list_size = connection.ops.max_in_list_size()
        if (
            self.rhs_is_direct_value()
            and max_in_list_size
            and len(self.rhs) > max_in_list_size
        ):
            return self.split_parameter_list_as_sql(compiler, connection)
        return super().as_sql(compiler, connection)

    def split_parameter_list_as_sql(self, compiler, connection):
        """
        Render ``(lhs IN (...) OR lhs IN (...) ...)`` with at most
        ``max_in_list_size`` parameters per group.

        This is a special case for databases which limit the number of
        elements which can appear in an 'IN' clause. Raise EmptyResultSet if
        no non-None values remain.
        """
        max_in_list_size = connection.ops.max_in_list_size()
        lhs, lhs_params = self.process_lhs(compiler, connection)
        # Same None handling as process_rhs(): a NULL inside an IN list makes
        # NOT (lhs IN (...)) UNKNOWN for every non-matching row, so exclude()
        # would drop rows only when the list is long enough to be split.
        rhs_values = self._get_direct_rhs_values()
        if not rhs_values:
            raise EmptyResultSet
        rhs, rhs_params = self.batch_process_rhs(compiler, connection, rhs_values)
        in_clause_elements = ["("]
        params = []
        for offset in range(0, len(rhs_params), max_in_list_size):
            if offset > 0:
                in_clause_elements.append(" OR ")
            in_clause_elements.append("%s IN (" % lhs)
            # The lhs is repeated in every group, and so are its params.
            params.extend(lhs_params)
            sqls = rhs[offset : offset + max_in_list_size]
            sqls_params = rhs_params[offset : offset + max_in_list_size]
            param_group = ", ".join(sqls)
            in_clause_elements.append(param_group)
            in_clause_elements.append(")")
            params.extend(sqls_params)
        in_clause_elements.append(")")
        return "".join(in_clause_elements), params
class PatternLookup(BuiltinLookup):
    """
    Base class for LIKE-style lookups (contains / startswith / endswith).

    ``param_pattern`` is %-formatted with the escaped Python value; ``%%``
    in it produces a literal ``%`` wildcard.
    """

    param_pattern = "%%%s%%"
    prepare_rhs = False

    def get_rhs_op(self, connection, rhs):
        # Assume we are in startswith. We need to produce SQL like:
        #     col LIKE %s, ['thevalue%']
        # For python values we can (and should) do that directly in Python,
        # but if the value is for example reference to other column, then
        # we need to add the % pattern match to the lookup by something like
        #     col LIKE othercol || '%%'
        # So, for Python values we don't need any special pattern, but for
        # SQL reference values or SQL transformations we need the correct
        # pattern added.
        is_sql_rhs = hasattr(self.rhs, "as_sql") or self.bilateral_transforms
        if not is_sql_rhs:
            return super().get_rhs_op(connection, rhs)
        template = connection.pattern_ops[self.lookup_name].format(
            connection.pattern_esc
        )
        return template.format(rhs)

    def process_rhs(self, qn, connection):
        rhs_sql, rhs_params = super().process_rhs(qn, connection)
        wrap_param = (
            self.rhs_is_direct_value()
            and rhs_params
            and not self.bilateral_transforms
        )
        if wrap_param:
            escaped = connection.ops.prep_for_like_query(rhs_params[0])
            rhs_params[0] = self.param_pattern % escaped
        return rhs_sql, rhs_params
@Field.register_lookup
class Contains(PatternLookup):
    """Substring match (``contains``)."""

    lookup_name = "contains"


@Field.register_lookup
class IContains(Contains):
    """Case-insensitive substring match (``icontains``)."""

    lookup_name = "icontains"


@Field.register_lookup
class StartsWith(PatternLookup):
    """Prefix match (``startswith``): wildcard only after the value."""

    lookup_name = "startswith"
    param_pattern = "%s%%"


@Field.register_lookup
class IStartsWith(StartsWith):
    """Case-insensitive prefix match (``istartswith``)."""

    lookup_name = "istartswith"


@Field.register_lookup
class EndsWith(PatternLookup):
    """Suffix match (``endswith``): wildcard only before the value."""

    lookup_name = "endswith"
    param_pattern = "%%%s"


@Field.register_lookup
class IEndsWith(EndsWith):
    """Case-insensitive suffix match (``iendswith``)."""

    lookup_name = "iendswith"
@Field.register_lookup
class Range(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
    """``BETWEEN`` lookup using the first two compiled rhs fragments as bounds."""

    lookup_name = "range"

    def get_rhs_op(self, connection, rhs):
        low, high = rhs[0], rhs[1]
        return "BETWEEN %s AND %s" % (low, high)
@Field.register_lookup
class IsNull(BuiltinLookup):
    """``IS NULL`` / ``IS NOT NULL`` lookup; the rhs must be a bool."""

    lookup_name = "isnull"
    prepare_rhs = False

    def as_sql(self, compiler, connection):
        if not isinstance(self.rhs, bool):
            raise ValueError(
                "The QuerySet value for an isnull lookup must be True or False."
            )
        sql, params = self.process_lhs(compiler, connection)
        predicate = "IS NULL" if self.rhs else "IS NOT NULL"
        return "%s %s" % (sql, predicate), params
@Field.register_lookup
class Regex(BuiltinLookup):
    """
    Regular-expression match (``regex``).

    Uses ``connection.operators`` when the backend lists this lookup there.
    Otherwise falls back to the SQL template from
    ``connection.ops.regex_lookup()``.
    """

    lookup_name = "regex"
    prepare_rhs = False

    def as_sql(self, compiler, connection):
        if self.lookup_name in connection.operators:
            return super().as_sql(compiler, connection)
        lhs, lhs_params = self.process_lhs(compiler, connection)
        rhs, rhs_params = self.process_rhs(compiler, connection)
        sql_template = connection.ops.regex_lookup(self.lookup_name)
        # process_rhs() may return params straight from compiler.compile(),
        # which need not be a list (e.g. a tuple). ``list + tuple`` raises
        # TypeError, so build a new list from both sides, as
        # BuiltinLookup.as_sql() does with extend().
        return sql_template % (lhs, rhs), [*lhs_params, *rhs_params]
@Field.register_lookup
class IRegex(Regex):
    """Case-insensitive regular-expression match (``iregex``)."""

    lookup_name = "iregex"
class YearLookup(Lookup):
    """
    Base class for comparisons on an extracted year.

    ``self.lhs`` is the year-extraction expression and ``self.lhs.lhs`` is
    the underlying date/datetime expression.
    """

    def year_lookup_bounds(self, connection, year):
        """Return the backend's bounds for ``year`` as ``(start, finish)``."""
        from django.db.models.functions import ExtractIsoYear

        iso_year = isinstance(self.lhs, ExtractIsoYear)
        if isinstance(self.lhs.lhs.output_field, DateTimeField):
            bounds_for = connection.ops.year_lookup_bounds_for_datetime_field
        else:
            bounds_for = connection.ops.year_lookup_bounds_for_date_field
        return bounds_for(year, iso_year=iso_year)

    def as_sql(self, compiler, connection):
        if not self.rhs_is_direct_value():
            return super().as_sql(compiler, connection)
        # With a literal year, skip the extract operation and compare the
        # originating field (self.lhs.lhs) against the year's bounds, so
        # indexes can be used. The compiled rhs params are replaced by the
        # bound params.
        lhs_sql, params = self.process_lhs(compiler, connection, self.lhs.lhs)
        rhs_sql, _ = self.process_rhs(compiler, connection)
        rhs_sql = self.get_direct_rhs_sql(connection, rhs_sql)
        start, finish = self.year_lookup_bounds(connection, self.rhs)
        params.extend(self.get_bound_params(start, finish))
        return "%s %s" % (lhs_sql, rhs_sql), params

    def get_direct_rhs_sql(self, connection, rhs):
        """Return the operator SQL used against the bound params."""
        operator_template = connection.operators[self.lookup_name]
        return operator_template % rhs

    def get_bound_params(self, start, finish):
        """Return the params (chosen from start/finish) for the direct comparison."""
        raise NotImplementedError(
            "subclasses of YearLookup must provide a get_bound_params() method"
        )
class YearExact(YearLookup, Exact):
    """Year equality: the value lies between the year's start and finish."""

    def get_direct_rhs_sql(self, connection, rhs):
        return "BETWEEN %s AND %s"

    def get_bound_params(self, start, finish):
        return (start, finish)


class YearGt(YearLookup, GreaterThan):
    """Year greater than: compared against the year's finish bound."""

    def get_bound_params(self, start, finish):
        return (finish,)


class YearGte(YearLookup, GreaterThanOrEqual):
    """Year greater than or equal: compared against the year's start bound."""

    def get_bound_params(self, start, finish):
        return (start,)


class YearLt(YearLookup, LessThan):
    """Year less than: compared against the year's start bound."""

    def get_bound_params(self, start, finish):
        return (start,)


class YearLte(YearLookup, LessThanOrEqual):
    """Year less than or equal: compared against the year's finish bound."""

    def get_bound_params(self, start, finish):
        return (finish,)
class UUIDTextMixin:
    """
    Strip hyphens from a value when filtering a UUIDField on backends without
    a native datatype for UUID.
    """

    def process_rhs(self, qn, connection):
        if not connection.features.has_native_uuid_field:
            from django.db.models.functions import Replace

            source = Value(self.rhs) if self.rhs_is_direct_value() else self.rhs
            # self.rhs is replaced in place, so later steps (e.g. get_rhs_op())
            # see an expression rhs. NOTE(review): a second call on the same
            # lookup wraps another Replace around the previous one.
            self.rhs = Replace(
                source, Value("-"), Value(""), output_field=CharField()
            )
        rhs_sql, rhs_params = super().process_rhs(qn, connection)
        return rhs_sql, rhs_params
@UUIDField.register_lookup
class UUIDIExact(UUIDTextMixin, IExact):
    """``iexact`` for UUIDField, hyphen-insensitive on non-native backends."""


@UUIDField.register_lookup
class UUIDContains(UUIDTextMixin, Contains):
    """``contains`` for UUIDField, hyphen-insensitive on non-native backends."""


@UUIDField.register_lookup
class UUIDIContains(UUIDTextMixin, IContains):
    """``icontains`` for UUIDField, hyphen-insensitive on non-native backends."""


@UUIDField.register_lookup
class UUIDStartsWith(UUIDTextMixin, StartsWith):
    """``startswith`` for UUIDField, hyphen-insensitive on non-native backends."""


@UUIDField.register_lookup
class UUIDIStartsWith(UUIDTextMixin, IStartsWith):
    """``istartswith`` for UUIDField, hyphen-insensitive on non-native backends."""


@UUIDField.register_lookup
class UUIDEndsWith(UUIDTextMixin, EndsWith):
    """``endswith`` for UUIDField, hyphen-insensitive on non-native backends."""


@UUIDField.register_lookup
class UUIDIEndsWith(UUIDTextMixin, IEndsWith):
    """``iendswith`` for UUIDField, hyphen-insensitive on non-native backends."""