import datetime
import json

from psycopg2.extras import DateRange, DateTimeTZRange, NumericRange, Range

from django.contrib.postgres import forms, lookups
from django.db import models
from django.db.models.lookups import PostgresOperatorLookup

from .utils import AttributeSetter

__all__ = [
    "RangeField",
    "IntegerRangeField",
    "BigIntegerRangeField",
    "DecimalRangeField",
    "DateTimeRangeField",
    "DateRangeField",
    "RangeBoundary",
    "RangeOperators",
]


class RangeBoundary(models.Expression):
    """A class that represents range boundaries."""

    def __init__(self, inclusive_lower=True, inclusive_upper=False):
        """
        Store the bracket characters for each end of the range.

        An inclusive end is stored as a square bracket, an exclusive end as a
        parenthesis.
        """
        self.lower = "[" if inclusive_lower else "("
        self.upper = "]" if inclusive_upper else ")"

    def as_sql(self, compiler, connection):
        """Return the two bracket characters as a quoted SQL string literal."""
        return "'%s%s'" % (self.lower, self.upper), []


class RangeOperators:
    """Operator strings used by the range lookups defined in this module."""

    # https://www.postgresql.org/docs/current/functions-range.html#RANGE-OPERATORS-TABLE
    EQUAL = "="
    NOT_EQUAL = "<>"
    CONTAINS = "@>"
    CONTAINED_BY = "<@"
    OVERLAPS = "&&"
    FULLY_LT = "<<"
    FULLY_GT = ">>"
    NOT_LT = "&>"
    NOT_GT = "&<"
    ADJACENT_TO = "-|-"


class RangeField(models.Field):
    """
    Base class for the range model fields.

    Concrete subclasses provide ``base_field`` (the model field class used for
    each bound), ``range_type`` (the Python range class), ``form_field`` and
    ``db_type()``.
    """

    empty_strings_allowed = False

    def __init__(self, *args, **kwargs):
        """
        Instantiate the field and, when the class declares one, its base field.

        Raises TypeError if ``default_bounds`` is passed; only
        ContinuousRangeField accepts that argument.
        """
        if "default_bounds" in kwargs:
            raise TypeError(
                f"Cannot use 'default_bounds' with {self.__class__.__name__}."
            )
        # Initializing base_field here ensures that its model matches the model
        # for self.
        if hasattr(self, "base_field"):
            self.base_field = self.base_field()
        super().__init__(*args, **kwargs)

    @property
    def model(self):
        """Return the model stored on the instance, if one has been set."""
        try:
            return self.__dict__["model"]
        except KeyError:
            raise AttributeError(
                "'%s' object has no attribute 'model'" % self.__class__.__name__
            )

    @model.setter
    def model(self, model):
        # Keep the base field attached to the same model as this field.
        self.__dict__["model"] = model
        self.base_field.model = model

    @classmethod
    def _choices_is_value(cls, value):
        """Treat lists and tuples as single choice values."""
        return isinstance(value, (list, tuple)) or super()._choices_is_value(value)

    def get_placeholder(self, value, compiler, connection):
        """Return a placeholder cast to this field's database type."""
        return "%s::{}".format(self.db_type(connection))

    def get_prep_value(self, value):
        """
        Convert ``value`` for use as a query parameter.

        ``None`` and Range instances are returned unchanged, a list or tuple is
        turned into ``range_type(value[0], value[1])``, and anything else is
        passed through as is.
        """
        if value is None:
            return None
        elif isinstance(value, Range):
            return value
        elif isinstance(value, (list, tuple)):
            return self.range_type(value[0], value[1])
        return value

    def to_python(self, value):
        """
        Convert ``value`` to a range object where possible.

        A string is parsed as JSON, its "lower" and "upper" entries are
        converted with the base field, and the result is passed to
        ``range_type`` as keyword arguments. A list or tuple becomes
        ``range_type(value[0], value[1])``. Other values are returned
        unchanged.
        """
        if isinstance(value, str):
            # Assume we're deserializing
            vals = json.loads(value)
            for end in ("lower", "upper"):
                if end in vals:
                    vals[end] = self.base_field.to_python(vals[end])
            value = self.range_type(**vals)
        elif isinstance(value, (list, tuple)):
            value = self.range_type(value[0], value[1])
        return value

    def set_attributes_from_name(self, name):
        """Apply the field name to both this field and its base field."""
        super().set_attributes_from_name(name)
        self.base_field.set_attributes_from_name(name)

    def value_to_string(self, obj):
        """
        Serialize the field value of ``obj`` to a JSON string.

        Return None for a None value and ``{"empty": true}`` for an empty
        range. Otherwise the JSON object holds "bounds" plus "lower" and
        "upper", each either null or the base field's string form.
        """
        value = self.value_from_object(obj)
        if value is None:
            return None
        if value.isempty:
            return json.dumps({"empty": True})
        base_field = self.base_field
        result = {"bounds": value._bounds}
        for end in ("lower", "upper"):
            val = getattr(value, end)
            if val is None:
                result[end] = None
            else:
                # Wrap the bound so the base field can read it by attname.
                obj = AttributeSetter(base_field.attname, val)
                result[end] = base_field.value_to_string(obj)
        return json.dumps(result)

    def formfield(self, **kwargs):
        """Return a form field, defaulting ``form_class`` to ``form_field``."""
        kwargs.setdefault("form_class", self.form_field)
        return super().formfield(**kwargs)


CANONICAL_RANGE_BOUNDS = "[)"


class ContinuousRangeField(RangeField):
    """
    Continuous range field. It allows specifying default bounds for list and
    tuple inputs.
    """

    def __init__(self, *args, default_bounds=CANONICAL_RANGE_BOUNDS, **kwargs):
        """
        Validate and store ``default_bounds``.

        Raises ValueError unless ``default_bounds`` is one of '[)', '(]', '()'
        or '[]'.
        """
        if default_bounds not in ("[)", "(]", "()", "[]"):
            raise ValueError("default_bounds must be one of '[)', '(]', '()', or '[]'.")
        self.default_bounds = default_bounds
        super().__init__(*args, **kwargs)

    def get_prep_value(self, value):
        """Build list and tuple inputs with ``default_bounds``."""
        if isinstance(value, (list, tuple)):
            return self.range_type(value[0], value[1], self.default_bounds)
        return super().get_prep_value(value)

    def formfield(self, **kwargs):
        """Pass ``default_bounds`` on to the form field unless overridden."""
        kwargs.setdefault("default_bounds", self.default_bounds)
        return super().formfield(**kwargs)

    def deconstruct(self):
        """Include ``default_bounds`` only when it differs from the default."""
        name, path, args, kwargs = super().deconstruct()
        if self.default_bounds and self.default_bounds != CANONICAL_RANGE_BOUNDS:
            kwargs["default_bounds"] = self.default_bounds
        return name, path, args, kwargs


class IntegerRangeField(RangeField):
    """Range of IntegerField values stored as ``int4range``."""

    base_field = models.IntegerField
    range_type = NumericRange
    form_field = forms.IntegerRangeField

    def db_type(self, connection):
        return "int4range"


class BigIntegerRangeField(RangeField):
    """Range of BigIntegerField values stored as ``int8range``."""

    base_field = models.BigIntegerField
    range_type = NumericRange
    form_field = forms.IntegerRangeField

    def db_type(self, connection):
        return "int8range"


class DecimalRangeField(ContinuousRangeField):
    """Range of DecimalField values stored as ``numrange``."""

    base_field = models.DecimalField
    range_type = NumericRange
    form_field = forms.DecimalRangeField

    def db_type(self, connection):
        return "numrange"


class DateTimeRangeField(ContinuousRangeField):
    """Range of DateTimeField values stored as ``tstzrange``."""

    base_field = models.DateTimeField
    range_type = DateTimeTZRange
    form_field = forms.DateTimeRangeField

    def db_type(self, connection):
        return "tstzrange"


class DateRangeField(RangeField):
    """Range of DateField values stored as ``daterange``."""

    base_field = models.DateField
    range_type = DateRange
    form_field = forms.DateRangeField

    def db_type(self, connection):
        return "daterange"


RangeField.register_lookup(lookups.DataContains)
RangeField.register_lookup(lookups.ContainedBy)
RangeField.register_lookup(lookups.Overlap)


class DateTimeRangeContains(PostgresOperatorLookup):
    """
    Lookup for Date/DateTimeRange containment to cast the rhs to the correct
    type.
    """

    lookup_name = "contains"
    postgres_operator = RangeOperators.CONTAINS

    def process_rhs(self, compiler, connection):
        # Transform rhs value for db lookup.
        if isinstance(self.rhs, datetime.date):
            value = models.Value(self.rhs)
            self.rhs = value.resolve_expression(compiler.query)
        return super().process_rhs(compiler, connection)

    def as_postgresql(self, compiler, connection):
        """Append a cast to the base field's database type when needed."""
        sql, params = super().as_postgresql(compiler, connection)
        # Cast the rhs if needed.
        cast_sql = ""
        if (
            isinstance(self.rhs, models.Expression)
            and self.rhs._output_field_or_none
            and
            # Skip cast if rhs has a matching range type.
            not isinstance(
                self.rhs._output_field_or_none, self.lhs.output_field.__class__
            )
        ):
            cast_internal_type = self.lhs.output_field.base_field.get_internal_type()
            cast_sql = "::{}".format(connection.data_types.get(cast_internal_type))
        return "%s%s" % (sql, cast_sql), params


DateRangeField.register_lookup(DateTimeRangeContains)
DateTimeRangeField.register_lookup(DateTimeRangeContains)


class RangeContainedBy(PostgresOperatorLookup):
    """
    ``contained_by`` lookup for the non-range fields registered below.

    The right-hand side is cast to the range type that ``type_mapping``
    associates with the left-hand side's database type.
    """

    lookup_name = "contained_by"
    type_mapping = {
        "smallint": "int4range",
        "integer": "int4range",
        "bigint": "int8range",
        "double precision": "numrange",
        "numeric": "numrange",
        "date": "daterange",
        "timestamp with time zone": "tstzrange",
    }
    postgres_operator = RangeOperators.CONTAINED_BY

    def process_rhs(self, compiler, connection):
        """Cast the right-hand side to the mapped range type."""
        rhs, rhs_params = super().process_rhs(compiler, connection)
        # Ignore precision for DecimalFields.
        db_type = self.lhs.output_field.cast_db_type(connection).split("(")[0]
        cast_type = self.type_mapping[db_type]
        return "%s::%s" % (rhs, cast_type), rhs_params

    def process_lhs(self, compiler, connection):
        """Cast float columns to numeric and small integer columns to integer."""
        lhs, lhs_params = super().process_lhs(compiler, connection)
        if isinstance(self.lhs.output_field, models.FloatField):
            lhs = "%s::numeric" % lhs
        elif isinstance(self.lhs.output_field, models.SmallIntegerField):
            lhs = "%s::integer" % lhs
        return lhs, lhs_params

    def _get_range_type(self):
        """Return the Python range class matching the left-hand side field."""
        output_field = self.lhs.output_field
        # DateTimeField subclasses DateField, so it has to be tested first.
        if isinstance(output_field, models.DateTimeField):
            return DateTimeTZRange
        if isinstance(output_field, models.DateField):
            return DateRange
        return NumericRange

    def get_prep_lookup(self):
        """
        Prepare the right-hand side value.

        A list or tuple is turned into a range of the type matching the
        left-hand side field. Every other value goes through
        ``RangeField.get_prep_value()``.
        """
        rhs = self.rhs
        if isinstance(rhs, (list, tuple)):
            # A bare RangeField has no range_type, so passing a sequence to
            # RangeField().get_prep_value() would raise AttributeError.
            return self._get_range_type()(rhs[0], rhs[1])
        return RangeField().get_prep_value(rhs)


models.DateField.register_lookup(RangeContainedBy)
models.DateTimeField.register_lookup(RangeContainedBy)
models.IntegerField.register_lookup(RangeContainedBy)
models.FloatField.register_lookup(RangeContainedBy)
models.DecimalField.register_lookup(RangeContainedBy)


@RangeField.register_lookup
class FullyLessThan(PostgresOperatorLookup):
    """``fully_lt`` lookup using the ``<<`` operator."""

    lookup_name = "fully_lt"
    postgres_operator = RangeOperators.FULLY_LT


@RangeField.register_lookup
class FullGreaterThan(PostgresOperatorLookup):
    """``fully_gt`` lookup using the ``>>`` operator."""

    lookup_name = "fully_gt"
    postgres_operator = RangeOperators.FULLY_GT


@RangeField.register_lookup
class NotLessThan(PostgresOperatorLookup):
    """``not_lt`` lookup using the ``&>`` operator."""

    lookup_name = "not_lt"
    postgres_operator = RangeOperators.NOT_LT


@RangeField.register_lookup
class NotGreaterThan(PostgresOperatorLookup):
    """``not_gt`` lookup using the ``&<`` operator."""

    lookup_name = "not_gt"
    postgres_operator = RangeOperators.NOT_GT


@RangeField.register_lookup
class AdjacentToLookup(PostgresOperatorLookup):
    """``adjacent_to`` lookup using the ``-|-`` operator."""

    lookup_name = "adjacent_to"
    postgres_operator = RangeOperators.ADJACENT_TO


@RangeField.register_lookup
class RangeStartsWith(models.Transform):
    """``startswith`` transform calling ``lower``; typed as the base field."""

    lookup_name = "startswith"
    function = "lower"

    @property
    def output_field(self):
        return self.lhs.output_field.base_field


@RangeField.register_lookup
class RangeEndsWith(models.Transform):
    """``endswith`` transform calling ``upper``; typed as the base field."""

    lookup_name = "endswith"
    function = "upper"

    @property
    def output_field(self):
        return self.lhs.output_field.base_field


@RangeField.register_lookup
class IsEmpty(models.Transform):
    """``isempty`` transform calling ``isempty``; yields a boolean."""

    lookup_name = "isempty"
    function = "isempty"
    output_field = models.BooleanField()


@RangeField.register_lookup
class LowerInclusive(models.Transform):
    """``lower_inc`` transform calling ``LOWER_INC``; yields a boolean."""

    lookup_name = "lower_inc"
    function = "LOWER_INC"
    output_field = models.BooleanField()


@RangeField.register_lookup
class LowerInfinite(models.Transform):
    """``lower_inf`` transform calling ``LOWER_INF``; yields a boolean."""

    lookup_name = "lower_inf"
    function = "LOWER_INF"
    output_field = models.BooleanField()


@RangeField.register_lookup
class UpperInclusive(models.Transform):
    """``upper_inc`` transform calling ``UPPER_INC``; yields a boolean."""

    lookup_name = "upper_inc"
    function = "UPPER_INC"
    output_field = models.BooleanField()


@RangeField.register_lookup
class UpperInfinite(models.Transform):
    """``upper_inf`` transform calling ``UPPER_INF``; yields a boolean."""

    lookup_name = "upper_inf"
    function = "UPPER_INF"
    output_field = models.BooleanField()