"""JSONField model field and the lookups/transforms registered for it."""

import json

from django import forms
from django.core import checks, exceptions
from django.db import NotSupportedError, connections, router
from django.db.models import lookups
from django.db.models.lookups import PostgresOperatorLookup, Transform
from django.utils.translation import gettext_lazy as _

from . import Field
from .mixins import CheckFieldDefaultMixin

__all__ = ["JSONField"]


class JSONField(CheckFieldDefaultMixin, Field):
    """
    Model field that stores Python values as JSON.

    Values are serialized with json.dumps() (see get_prep_value()) and
    database values are decoded with json.loads() (see from_db_value()).
    """

    empty_strings_allowed = False
    description = _("A JSON object")
    default_error_messages = {
        "invalid": _("Value must be valid JSON."),
    }
    _default_hint = ("dict", "{}")

    def __init__(
        self,
        verbose_name=None,
        name=None,
        encoder=None,
        decoder=None,
        **kwargs,
    ):
        """
        Store the optional JSON encoder/decoder and initialize the field.

        ``encoder`` and ``decoder`` are passed as ``cls`` to json.dumps() and
        json.loads() respectively. A truthy value that is not callable raises
        ValueError.
        """
        if encoder and not callable(encoder):
            raise ValueError("The encoder parameter must be a callable object.")
        if decoder and not callable(decoder):
            raise ValueError("The decoder parameter must be a callable object.")
        self.encoder = encoder
        self.decoder = decoder
        super().__init__(verbose_name, name, **kwargs)

    def check(self, **kwargs):
        """
        Return the inherited check messages plus the JSON support errors for
        the database aliases given in ``kwargs["databases"]``.
        """
        errors = super().check(**kwargs)
        # A missing or None "databases" entry means no backend is checked.
        databases = kwargs.get("databases") or []
        errors.extend(self._check_supported(databases))
        return errors

    def _check_supported(self, databases):
        """
        Return a list with one fields.E180 error per database alias whose
        backend doesn't support JSON fields.

        Aliases are skipped when the router disallows migrating the model
        there or when the model requires a different database vendor. No error
        is reported if the model lists "supports_json_field" in its
        required_db_features.
        """
        errors = []
        for db in databases:
            if not router.allow_migrate_model(db, self.model):
                continue
            connection = connections[db]
            if (
                self.model._meta.required_db_vendor
                and self.model._meta.required_db_vendor != connection.vendor
            ):
                continue
            if not (
                "supports_json_field" in self.model._meta.required_db_features
                or connection.features.supports_json_field
            ):
                errors.append(
                    checks.Error(
                        "%s does not support JSONFields." % connection.display_name,
                        obj=self.model,
                        id="fields.E180",
                    )
                )
        return errors

    def deconstruct(self):
        """
        Extend the inherited deconstruction with the ``encoder`` and
        ``decoder`` keyword arguments when they are set.
        """
        name, path, args, kwargs = super().deconstruct()
        if self.encoder is not None:
            kwargs["encoder"] = self.encoder
        if self.decoder is not None:
            kwargs["decoder"] = self.decoder
        return name, path, args, kwargs

    def from_db_value(self, value, expression, connection):
        """
        Decode a database value with json.loads() and the field's decoder.

        None is returned unchanged, as is any value that json.loads() rejects
        with JSONDecodeError.
        """
        if value is None:
            return value
        # Some backends (SQLite at least) extract non-string values in their
        # SQL datatypes.
        if isinstance(expression, KeyTransform) and not isinstance(value, str):
            return value
        try:
            return json.loads(value, cls=self.decoder)
        except json.JSONDecodeError:
            return value

    def get_internal_type(self):
        """Return the internal type name, "JSONField"."""
        return "JSONField"

    def get_prep_value(self, value):
        """
        Serialize ``value`` with json.dumps() and the field's encoder. None is
        returned unchanged.
        """
        if value is None:
            return value
        return json.dumps(value, cls=self.encoder)

    def get_transform(self, name):
        """
        Return the inherited transform for ``name`` if there is one; otherwise
        return a KeyTransformFactory that treats ``name`` as a JSON key.
        """
        transform = super().get_transform(name)
        if transform:
            return transform
        return KeyTransformFactory(name)

    def validate(self, value, model_instance):
        """
        Run the inherited validation, then raise ValidationError with code
        "invalid" if json.dumps() raises TypeError for ``value``.
        """
        super().validate(value, model_instance)
        try:
            json.dumps(value, cls=self.encoder)
        except TypeError:
            raise exceptions.ValidationError(
                self.error_messages["invalid"],
                code="invalid",
                params={"value": value},
            )

    def value_to_string(self, obj):
        """Return the field's value on ``obj`` as given by value_from_object()."""
        return self.value_from_object(obj)

    def formfield(self, **kwargs):
        """
        Return the inherited form field, defaulting to forms.JSONField with
        this field's encoder and decoder. Entries in ``kwargs`` take
        precedence over these defaults.
        """
        return super().formfield(
            **{
                "form_class": forms.JSONField,
                "encoder": self.encoder,
                "decoder": self.decoder,
                **kwargs,
            }
        )


def compile_json_path(key_transforms, include_root=True):
    """
    Build a JSON path string from a sequence of keys.

    Keys accepted by int() become array indexes ("[N]"); all other keys become
    ".<json.dumps(key)>" member accessors. The path starts with "$" unless
    ``include_root`` is false. For example, ["a", "1"] compiles to '$."a"[1]'.
    """
    path = ["$"] if include_root else []
    for key_transform in key_transforms:
        try:
            num = int(key_transform)
        except ValueError:  # non-integer
            path.append(".")
            path.append(json.dumps(key_transform))
        else:
            path.append("[%s]" % num)
    return "".join(path)


class DataContains(PostgresOperatorLookup):
    """The ``contains`` lookup: the lhs JSON contains the rhs JSON."""

    lookup_name = "contains"
    postgres_operator = "@>"

    def as_sql(self, compiler, connection):
        """
        Return "JSON_CONTAINS(lhs, rhs)" and its parameters.

        Raise NotSupportedError if the backend's
        supports_json_field_contains feature flag is false.
        """
        if not connection.features.supports_json_field_contains:
            raise NotSupportedError(
                "contains lookup is not supported on this database backend."
            )
        lhs, lhs_params = self.process_lhs(compiler, connection)
        rhs, rhs_params = self.process_rhs(compiler, connection)
        params = tuple(lhs_params) + tuple(rhs_params)
        return "JSON_CONTAINS(%s, %s)" % (lhs, rhs), params


class ContainedBy(PostgresOperatorLookup):
    """The ``contained_by`` lookup: the lhs JSON is contained in the rhs JSON."""

    lookup_name = "contained_by"
    postgres_operator = "<@"

    def as_sql(self, compiler, connection):
        """
        Return "JSON_CONTAINS(rhs, lhs)" and its parameters.

        Raise NotSupportedError if the backend's
        supports_json_field_contains feature flag is false.
        """
        if not connection.features.supports_json_field_contains:
            raise NotSupportedError(
                "contained_by lookup is not supported on this database backend."
            )
        lhs, lhs_params = self.process_lhs(compiler, connection)
        rhs, rhs_params = self.process_rhs(compiler, connection)
        # Operands are swapped relative to DataContains, so the parameters
        # must be swapped as well to stay aligned with the placeholders.
        params = tuple(rhs_params) + tuple(lhs_params)
        return "JSON_CONTAINS(%s, %s)" % (rhs, lhs), params


class HasKeyLookup(PostgresOperatorLookup):
    """
    Base class for the key-existence lookups.

    Subclasses set ``logical_operator`` to the SQL connective used to combine
    the per-key conditions when several keys are given.
    """

    logical_operator = None

    def compile_json_path_final_key(self, key_transform):
        # Compile the final key without interpreting ints as array elements.
        return ".%s" % json.dumps(key_transform)

    def as_sql(self, compiler, connection, template=None):
        """
        Return the SQL and parameters testing the existence of each rhs key.

        ``template`` is a %-format string receiving the compiled lhs; the
        vendor-specific methods of this class supply it. Each rhs key
        contributes one full JSON path parameter.
        """
        # Process JSON path from the left-hand side.
        if isinstance(self.lhs, KeyTransform):
            lhs, lhs_params, lhs_key_transforms = self.lhs.preprocess_lhs(
                compiler, connection
            )
            lhs_json_path = compile_json_path(lhs_key_transforms)
        else:
            lhs, lhs_params = self.process_lhs(compiler, connection)
            lhs_json_path = "$"
        sql = template % lhs
        # Process JSON path from the right-hand side.
        rhs = self.rhs
        rhs_params = []
        if not isinstance(rhs, (list, tuple)):
            rhs = [rhs]
        for key in rhs:
            if isinstance(key, KeyTransform):
                *_, rhs_key_transforms = key.preprocess_lhs(compiler, connection)
            else:
                rhs_key_transforms = [key]
            *rhs_key_transforms, final_key = rhs_key_transforms
            rhs_json_path = compile_json_path(rhs_key_transforms, include_root=False)
            rhs_json_path += self.compile_json_path_final_key(final_key)
            rhs_params.append(lhs_json_path + rhs_json_path)
        # Add condition for each key.
        if self.logical_operator:
            sql = "(%s)" % self.logical_operator.join([sql] * len(rhs_params))
        return sql, tuple(lhs_params) + tuple(rhs_params)

    def as_mysql(self, compiler, connection):
        """Compile the lookup using JSON_CONTAINS_PATH()."""
        return self.as_sql(
            compiler, connection, template="JSON_CONTAINS_PATH(%s, 'one', %%s)"
        )

    def as_oracle(self, compiler, connection):
        """
        Compile the lookup using JSON_EXISTS(), with the paths interpolated
        into the SQL string and an empty parameter list.
        """
        sql, params = self.as_sql(
            compiler, connection, template="JSON_EXISTS(%s, '%%s')"
        )
        # Add paths directly into SQL because path expressions cannot be passed
        # as bind variables on Oracle.
        # NOTE(review): the path is placed inside a single-quoted SQL literal
        # and json.dumps() doesn't escape single quotes - confirm that keys
        # from untrusted input can't reach this lookup.
        return sql % tuple(params), []

    def as_postgresql(self, compiler, connection):
        """
        Compile the lookup using ``postgres_operator``.

        When the rhs is a KeyTransform, all of its keys but the last are
        folded into the lhs and the last key becomes the rhs.
        """
        if isinstance(self.rhs, KeyTransform):
            *_, rhs_key_transforms = self.rhs.preprocess_lhs(compiler, connection)
            for key in rhs_key_transforms[:-1]:
                self.lhs = KeyTransform(key, self.lhs)
            self.rhs = rhs_key_transforms[-1]
        return super().as_postgresql(compiler, connection)

    def as_sqlite(self, compiler, connection):
        """Compile the lookup using "JSON_TYPE(...) IS NOT NULL"."""
        return self.as_sql(
            compiler, connection, template="JSON_TYPE(%s, %%s) IS NOT NULL"
        )


class HasKey(HasKeyLookup):
    """The ``has_key`` lookup."""

    lookup_name = "has_key"
    postgres_operator = "?"
    prepare_rhs = False


class HasKeys(HasKeyLookup):
    """The ``has_keys`` lookup; the per-key conditions are joined with AND."""

    lookup_name = "has_keys"
    postgres_operator = "?&"
    logical_operator = " AND "

    def get_prep_lookup(self):
        """Return the rhs items converted to strings."""
        return [str(item) for item in self.rhs]


class HasAnyKeys(HasKeys):
    """The ``has_any_keys`` lookup; the per-key conditions are joined with OR."""

    lookup_name = "has_any_keys"
    postgres_operator = "?|"
    logical_operator = " OR "


class HasKeyOrArrayIndex(HasKey):
    """
    HasKey variant whose final key is compiled with compile_json_path(), so an
    integer-like final key becomes an array index.
    """

    def compile_json_path_final_key(self, key_transform):
        return compile_json_path([key_transform], include_root=False)


class CaseInsensitiveMixin:
    """
    Mixin to allow case-insensitive comparison of JSON values on MySQL.
    MySQL handles strings used in JSON context using the utf8mb4_bin collation.
    Because utf8mb4_bin is a binary collation, comparison of JSON values is
    case-sensitive.
    """

    def process_lhs(self, compiler, connection):
        """Wrap the compiled lhs in LOWER() when the vendor is "mysql"."""
        lhs, lhs_params = super().process_lhs(compiler, connection)
        if connection.vendor == "mysql":
            return "LOWER(%s)" % lhs, lhs_params
        return lhs, lhs_params

    def process_rhs(self, compiler, connection):
        """Wrap the compiled rhs in LOWER() when the vendor is "mysql"."""
        rhs, rhs_params = super().process_rhs(compiler, connection)
        if connection.vendor == "mysql":
            return "LOWER(%s)" % rhs, rhs_params
        return rhs, rhs_params


class JSONExact(lookups.Exact):
    """Exact lookup for JSON values that accepts None as the rhs."""

    can_use_none_as_rhs = True

    def process_rhs(self, compiler, connection):
        """
        Compile the rhs, replacing a lone None parameter with the string
        "null" and wrapping each placeholder in JSON_EXTRACT() on MySQL.
        """
        rhs, rhs_params = super().process_rhs(compiler, connection)
        # Treat None lookup values as null.
        if rhs == "%s" and rhs_params == [None]:
            rhs_params = ["null"]
        if connection.vendor == "mysql":
            func = ["JSON_EXTRACT(%s, '$')"] * len(rhs_params)
            rhs = rhs % tuple(func)
        return rhs, rhs_params


class JSONIContains(CaseInsensitiveMixin, lookups.IContains):
    pass


JSONField.register_lookup(DataContains)
JSONField.register_lookup(ContainedBy)
JSONField.register_lookup(HasKey)
JSONField.register_lookup(HasKeys)
JSONField.register_lookup(HasAnyKeys)
JSONField.register_lookup(JSONExact)
JSONField.register_lookup(JSONIContains)


class KeyTransform(Transform):
    """
    Transform that extracts the value stored under ``key_name`` from a JSON
    expression. Chained KeyTransforms are compiled into a single JSON path.
    """

    postgres_operator = "->"
    postgres_nested_operator = "#>"

    def __init__(self, key_name, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Keys are always stored as strings.
        self.key_name = str(key_name)

    def preprocess_lhs(self, compiler, connection):
        """
        Return ``(lhs_sql, params, key_transforms)``.

        ``key_transforms`` lists the key names of this transform and of every
        directly nested KeyTransform, outermost key last. ``lhs_sql`` and
        ``params`` are the compiled innermost expression that isn't a
        KeyTransform.
        """
        key_transforms = [self.key_name]
        previous = self.lhs
        while isinstance(previous, KeyTransform):
            key_transforms.insert(0, previous.key_name)
            previous = previous.lhs
        lhs, params = compiler.compile(previous)
        if connection.vendor == "oracle":
            # Escape string-formatting.
            key_transforms = [key.replace("%", "%%") for key in key_transforms]
        return lhs, params, key_transforms

    def as_mysql(self, compiler, connection):
        """Compile to JSON_EXTRACT() with the JSON path as a parameter."""
        lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
        json_path = compile_json_path(key_transforms)
        return "JSON_EXTRACT(%s, %%s)" % lhs, tuple(params) + (json_path,)

    def as_oracle(self, compiler, connection):
        """
        Compile to COALESCE(JSON_QUERY(...), JSON_VALUE(...)) with the JSON
        path interpolated into the SQL string.
        """
        lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
        json_path = compile_json_path(key_transforms)
        # The lhs appears twice in the SQL, so its parameters are doubled.
        # NOTE(review): the path is placed inside a single-quoted SQL literal
        # and json.dumps() doesn't escape single quotes - confirm that keys
        # from untrusted input can't reach this transform.
        return (
            "COALESCE(JSON_QUERY(%s, '%s'), JSON_VALUE(%s, '%s'))"
            % ((lhs, json_path) * 2)
        ), tuple(params) * 2

    def as_postgresql(self, compiler, connection):
        """
        Compile using ``postgres_nested_operator`` with the list of keys when
        there are several keys, otherwise using ``postgres_operator`` with the
        single key (converted to int when possible).
        """
        lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
        if len(key_transforms) > 1:
            sql = "(%s %s %%s)" % (lhs, self.postgres_nested_operator)
            return sql, tuple(params) + (key_transforms,)
        try:
            lookup = int(self.key_name)
        except ValueError:
            lookup = self.key_name
        return "(%s %s %%s)" % (lhs, self.postgres_operator), tuple(params) + (lookup,)

    def as_sqlite(self, compiler, connection):
        """
        Compile to a CASE expression that yields JSON_TYPE() when that type is
        one of connection.ops.jsonfield_datatype_values and JSON_EXTRACT()
        otherwise.
        """
        lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
        json_path = compile_json_path(key_transforms)
        datatype_values = ",".join(
            [repr(datatype) for datatype in connection.ops.jsonfield_datatype_values]
        )
        # The lhs and the path appear three times in the SQL, so their
        # parameters are tripled.
        return (
            "(CASE WHEN JSON_TYPE(%s, %%s) IN (%s) "
            "THEN JSON_TYPE(%s, %%s) ELSE JSON_EXTRACT(%s, %%s) END)"
        ) % (lhs, datatype_values, lhs, lhs), (tuple(params) + (json_path,)) * 3


class KeyTextTransform(KeyTransform):
    """KeyTransform using the ``->>`` and ``#>>`` PostgreSQL operators."""

    postgres_operator = "->>"
    postgres_nested_operator = "#>>"


class KeyTransformTextLookupMixin:
    """
    Mixin for combining with a lookup expecting a text lhs from a JSONField
    key lookup. On PostgreSQL, make use of the ->> operator instead of casting
    key values to text and performing the lookup on the resulting
    representation.
    """

    def __init__(self, key_transform, *args, **kwargs):
        """
        Replace ``key_transform`` with an equivalent KeyTextTransform before
        initializing the lookup. Raise TypeError if it isn't a KeyTransform.
        """
        if not isinstance(key_transform, KeyTransform):
            raise TypeError(
                "Transform should be an instance of KeyTransform in order to "
                "use this lookup."
            )
        key_text_transform = KeyTextTransform(
            key_transform.key_name,
            *key_transform.source_expressions,
            **key_transform.extra,
        )
        super().__init__(key_text_transform, *args, **kwargs)


class KeyTransformIsNull(lookups.IsNull):
    # key__isnull=False is the same as has_key='key'
    def as_oracle(self, compiler, connection):
        """
        Compile ``isnull`` on Oracle through HasKeyOrArrayIndex: a false rhs
        is the key-existence test itself, a true rhs is its negation or a NULL
        column.
        """
        sql, params = HasKeyOrArrayIndex(
            self.lhs.lhs,
            self.lhs.key_name,
        ).as_oracle(compiler, connection)
        if not self.rhs:
            return sql, params
        # Column doesn't have a key or IS NULL.
        lhs, lhs_params, _ = self.lhs.preprocess_lhs(compiler, connection)
        return "(NOT %s OR %s IS NULL)" % (sql, lhs), tuple(params) + tuple(lhs_params)

    def as_sqlite(self, compiler, connection):
        """
        Compile ``isnull`` on SQLite as a JSON_TYPE() test: IS NULL for a true
        rhs, IS NOT NULL for a false rhs.
        """
        template = "JSON_TYPE(%s, %%s) IS NULL"
        if not self.rhs:
            template = "JSON_TYPE(%s, %%s) IS NOT NULL"
        return HasKeyOrArrayIndex(self.lhs.lhs, self.lhs.key_name).as_sql(
            compiler,
            connection,
            template=template,
        )


class KeyTransformIn(lookups.In):
    """The ``in`` lookup for a JSON key."""

    def resolve_expression_parameter(self, compiler, connection, sql, param):
        """
        Return the SQL and parameters for one rhs item.

        For plain values on backends without a native JSON field, the
        placeholder is wrapped in a vendor-specific JSON function.
        """
        sql, params = super().resolve_expression_parameter(
            compiler,
            connection,
            sql,
            param,
        )
        if (
            not hasattr(param, "as_sql")
            and not connection.features.has_native_json_field
        ):
            if connection.vendor == "oracle":
                value = json.loads(param)
                sql = "%s(JSON_OBJECT('value' VALUE %%s FORMAT JSON), '$.value')"
                # Containers are extracted with JSON_QUERY, other values with
                # JSON_VALUE.
                if isinstance(value, (list, dict)):
                    sql = sql % "JSON_QUERY"
                else:
                    sql = sql % "JSON_VALUE"
            elif connection.vendor == "mysql" or (
                connection.vendor == "sqlite"
                and params[0] not in connection.ops.jsonfield_datatype_values
            ):
                sql = "JSON_EXTRACT(%s, '$')"
        if connection.vendor == "mysql" and connection.mysql_is_mariadb:
            sql = "JSON_UNQUOTE(%s)" % sql
        return sql, params


class KeyTransformExact(JSONExact):
    """The ``exact`` lookup for a JSON key."""

    def process_rhs(self, compiler, connection):
        """
        Compile the rhs, wrapping each placeholder in a vendor-specific JSON
        function on Oracle and SQLite.

        A KeyTransform rhs bypasses the Exact/JSONExact processing entirely.
        """
        if isinstance(self.rhs, KeyTransform):
            return super(lookups.Exact, self).process_rhs(compiler, connection)
        rhs, rhs_params = super().process_rhs(compiler, connection)
        if connection.vendor == "oracle":
            func = []
            sql = "%s(JSON_OBJECT('value' VALUE %%s FORMAT JSON), '$.value')"
            for value in rhs_params:
                value = json.loads(value)
                if isinstance(value, (list, dict)):
                    func.append(sql % "JSON_QUERY")
                else:
                    func.append(sql % "JSON_VALUE")
            rhs = rhs % tuple(func)
        elif connection.vendor == "sqlite":
            func = []
            for value in rhs_params:
                if value in connection.ops.jsonfield_datatype_values:
                    func.append("%s")
                else:
                    func.append("JSON_EXTRACT(%s, '$')")
            rhs = rhs % tuple(func)
        return rhs, rhs_params

    def as_oracle(self, compiler, connection):
        """
        Compile the lookup on Oracle. A "null" rhs becomes "key exists AND
        value IS NULL"; anything else uses the regular as_sql().
        """
        rhs, rhs_params = super().process_rhs(compiler, connection)
        if rhs_params == ["null"]:
            # Field has key and it's NULL.
            has_key_expr = HasKeyOrArrayIndex(self.lhs.lhs, self.lhs.key_name)
            has_key_sql, has_key_params = has_key_expr.as_oracle(compiler, connection)
            is_null_expr = self.lhs.get_lookup("isnull")(self.lhs, True)
            is_null_sql, is_null_params = is_null_expr.as_sql(compiler, connection)
            return (
                "%s AND %s" % (has_key_sql, is_null_sql),
                tuple(has_key_params) + tuple(is_null_params),
            )
        return super().as_sql(compiler, connection)


class KeyTransformIExact(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IExact
):
    pass


class KeyTransformIContains(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IContains
):
    pass


class KeyTransformStartsWith(KeyTransformTextLookupMixin, lookups.StartsWith):
    pass


class KeyTransformIStartsWith(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IStartsWith
):
    pass


class KeyTransformEndsWith(KeyTransformTextLookupMixin, lookups.EndsWith):
    pass


class KeyTransformIEndsWith(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IEndsWith
):
    pass


class KeyTransformRegex(KeyTransformTextLookupMixin, lookups.Regex):
    pass


class KeyTransformIRegex(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IRegex
):
    pass


class KeyTransformNumericLookupMixin:
    """
    Mixin for comparison lookups on a JSON key: on backends without a native
    JSON field, the rhs parameters are decoded with json.loads().
    """

    def process_rhs(self, compiler, connection):
        rhs, rhs_params = super().process_rhs(compiler, connection)
        if not connection.features.has_native_json_field:
            rhs_params = [json.loads(value) for value in rhs_params]
        return rhs, rhs_params


class KeyTransformLt(KeyTransformNumericLookupMixin, lookups.LessThan):
    pass


class KeyTransformLte(KeyTransformNumericLookupMixin, lookups.LessThanOrEqual):
    pass


class KeyTransformGt(KeyTransformNumericLookupMixin, lookups.GreaterThan):
    pass


class KeyTransformGte(KeyTransformNumericLookupMixin, lookups.GreaterThanOrEqual):
    pass


KeyTransform.register_lookup(KeyTransformIn)
KeyTransform.register_lookup(KeyTransformExact)
KeyTransform.register_lookup(KeyTransformIExact)
KeyTransform.register_lookup(KeyTransformIsNull)
KeyTransform.register_lookup(KeyTransformIContains)
KeyTransform.register_lookup(KeyTransformStartsWith)
KeyTransform.register_lookup(KeyTransformIStartsWith)
KeyTransform.register_lookup(KeyTransformEndsWith)
KeyTransform.register_lookup(KeyTransformIEndsWith)
KeyTransform.register_lookup(KeyTransformRegex)
KeyTransform.register_lookup(KeyTransformIRegex)

KeyTransform.register_lookup(KeyTransformLt)
KeyTransform.register_lookup(KeyTransformLte)
KeyTransform.register_lookup(KeyTransformGt)
KeyTransform.register_lookup(KeyTransformGte)


class KeyTransformFactory:
    """Callable that creates a KeyTransform for a fixed key name."""

    def __init__(self, key_name):
        self.key_name = key_name

    def __call__(self, *args, **kwargs):
        return KeyTransform(self.key_name, *args, **kwargs)