|  | 
|  | 1 | +import json | 
|  | 2 | + | 
|  | 3 | +from django.contrib.postgres.validators import ArrayMaxLengthValidator | 
|  | 4 | +from django.core import checks, exceptions | 
|  | 5 | +from django.db.models import DecimalField, Field, Func, IntegerField, Transform, Value | 
|  | 6 | +from django.db.models.fields.mixins import CheckFieldDefaultMixin | 
|  | 7 | +from django.db.models.lookups import Exact, FieldGetDbPrepValueMixin, In, Lookup | 
|  | 8 | +from django.utils.translation import gettext_lazy as _ | 
|  | 9 | + | 
|  | 10 | +from ..forms import SimpleArrayField | 
|  | 11 | +from ..query_utils import process_lhs, process_rhs | 
|  | 12 | +from ..utils import prefix_validation_error | 
|  | 13 | + | 
|  | 14 | +__all__ = ["ArrayField"] | 
|  | 15 | + | 
|  | 16 | + | 
|  | 17 | +class AttributeSetter: | 
|  | 18 | + def __init__(self, name, value): | 
|  | 19 | + setattr(self, name, value) | 
|  | 20 | + | 
|  | 21 | + | 
|  | 22 | +class ArrayField(CheckFieldDefaultMixin, Field): | 
|  | 23 | + empty_strings_allowed = False | 
|  | 24 | + default_error_messages = { | 
|  | 25 | + "item_invalid": _("Item %(nth)s in the array did not validate:"), | 
|  | 26 | + "nested_array_mismatch": _("Nested arrays must have the same length."), | 
|  | 27 | + } | 
|  | 28 | + _default_hint = ("list", "[]") | 
|  | 29 | + | 
|  | 30 | + def __init__(self, base_field, size=None, **kwargs): | 
|  | 31 | + self.base_field = base_field | 
|  | 32 | + self.size = size | 
|  | 33 | + if self.size: | 
|  | 34 | + self.default_validators = [ | 
|  | 35 | + *self.default_validators, | 
|  | 36 | + ArrayMaxLengthValidator(self.size), | 
|  | 37 | + ] | 
|  | 38 | + # For performance, only add a from_db_value() method if the base field | 
|  | 39 | + # implements it. | 
|  | 40 | + if hasattr(self.base_field, "from_db_value"): | 
|  | 41 | + self.from_db_value = self._from_db_value | 
|  | 42 | + super().__init__(**kwargs) | 
|  | 43 | + | 
|  | 44 | + @property | 
|  | 45 | + def model(self): | 
|  | 46 | + try: | 
|  | 47 | + return self.__dict__["model"] | 
|  | 48 | + except KeyError: | 
|  | 49 | + raise AttributeError( | 
|  | 50 | + "'%s' object has no attribute 'model'" % self.__class__.__name__ | 
|  | 51 | + ) from None | 
|  | 52 | + | 
|  | 53 | + @model.setter | 
|  | 54 | + def model(self, model): | 
|  | 55 | + self.__dict__["model"] = model | 
|  | 56 | + self.base_field.model = model | 
|  | 57 | + | 
|  | 58 | + @classmethod | 
|  | 59 | + def _choices_is_value(cls, value): | 
|  | 60 | + return isinstance(value, list | tuple) or super()._choices_is_value(value) | 
|  | 61 | + | 
|  | 62 | + def check(self, **kwargs): | 
|  | 63 | + errors = super().check(**kwargs) | 
|  | 64 | + if self.base_field.remote_field: | 
|  | 65 | + errors.append( | 
|  | 66 | + checks.Error( | 
|  | 67 | + "Base field for array cannot be a related field.", | 
|  | 68 | + obj=self, | 
|  | 69 | + id="django_mongodb_backend.array.E002", | 
|  | 70 | + ) | 
|  | 71 | + ) | 
|  | 72 | + else: | 
|  | 73 | + base_checks = self.base_field.check() | 
|  | 74 | + if base_checks: | 
|  | 75 | + error_messages = "\n ".join( | 
|  | 76 | + f"{base_check.msg} ({base_check.id})" | 
|  | 77 | + for base_check in base_checks | 
|  | 78 | + if isinstance(base_check, checks.Error) | 
|  | 79 | + ) | 
|  | 80 | + if error_messages: | 
|  | 81 | + errors.append( | 
|  | 82 | + checks.Error( | 
|  | 83 | + f"Base field for array has errors:\n {error_messages}", | 
|  | 84 | + obj=self, | 
|  | 85 | + id="django_mongodb_backend.array.E001", | 
|  | 86 | + ) | 
|  | 87 | + ) | 
|  | 88 | + warning_messages = "\n ".join( | 
|  | 89 | + f"{base_check.msg} ({base_check.id})" | 
|  | 90 | + for base_check in base_checks | 
|  | 91 | + if isinstance(base_check, checks.Warning) | 
|  | 92 | + ) | 
|  | 93 | + if warning_messages: | 
|  | 94 | + errors.append( | 
|  | 95 | + checks.Warning( | 
|  | 96 | + f"Base field for array has warnings:\n {warning_messages}", | 
|  | 97 | + obj=self, | 
|  | 98 | + id="django_mongodb_backend.array.W004", | 
|  | 99 | + ) | 
|  | 100 | + ) | 
|  | 101 | + return errors | 
|  | 102 | + | 
|  | 103 | + def set_attributes_from_name(self, name): | 
|  | 104 | + super().set_attributes_from_name(name) | 
|  | 105 | + self.base_field.set_attributes_from_name(name) | 
|  | 106 | + | 
|  | 107 | + @property | 
|  | 108 | + def description(self): | 
|  | 109 | + return f"Array of {self.base_field.description}" | 
|  | 110 | + | 
|  | 111 | + def db_type(self, connection): | 
|  | 112 | + return "array" | 
|  | 113 | + | 
|  | 114 | + def get_db_prep_value(self, value, connection, prepared=False): | 
|  | 115 | + if isinstance(value, list | tuple): | 
|  | 116 | + # Workaround for https://code.djangoproject.com/ticket/35982 | 
|  | 117 | + # (fixed in Django 5.2). | 
|  | 118 | + if isinstance(self.base_field, DecimalField): | 
|  | 119 | + return [self.base_field.get_db_prep_save(i, connection) for i in value] | 
|  | 120 | + return [self.base_field.get_db_prep_value(i, connection, prepared=False) for i in value] | 
|  | 121 | + return value | 
|  | 122 | + | 
|  | 123 | + def deconstruct(self): | 
|  | 124 | + name, path, args, kwargs = super().deconstruct() | 
|  | 125 | + if path == "django_mongodb_backend.fields.array.ArrayField": | 
|  | 126 | + path = "django_mongodb_backend.fields.ArrayField" | 
|  | 127 | + kwargs.update( | 
|  | 128 | + { | 
|  | 129 | + "base_field": self.base_field.clone(), | 
|  | 130 | + "size": self.size, | 
|  | 131 | + } | 
|  | 132 | + ) | 
|  | 133 | + return name, path, args, kwargs | 
|  | 134 | + | 
|  | 135 | + def to_python(self, value): | 
|  | 136 | + if isinstance(value, str): | 
|  | 137 | + # Assume value is being deserialized. | 
|  | 138 | + vals = json.loads(value) | 
|  | 139 | + value = [self.base_field.to_python(val) for val in vals] | 
|  | 140 | + return value | 
|  | 141 | + | 
|  | 142 | + def _from_db_value(self, value, expression, connection): | 
|  | 143 | + if value is None: | 
|  | 144 | + return value | 
|  | 145 | + return [self.base_field.from_db_value(item, expression, connection) for item in value] | 
|  | 146 | + | 
|  | 147 | + def value_to_string(self, obj): | 
|  | 148 | + values = [] | 
|  | 149 | + vals = self.value_from_object(obj) | 
|  | 150 | + base_field = self.base_field | 
|  | 151 | + | 
|  | 152 | + for val in vals: | 
|  | 153 | + if val is None: | 
|  | 154 | + values.append(None) | 
|  | 155 | + else: | 
|  | 156 | + obj = AttributeSetter(base_field.attname, val) | 
|  | 157 | + values.append(base_field.value_to_string(obj)) | 
|  | 158 | + return json.dumps(values) | 
|  | 159 | + | 
|  | 160 | + def get_transform(self, name): | 
|  | 161 | + transform = super().get_transform(name) | 
|  | 162 | + if transform: | 
|  | 163 | + return transform | 
|  | 164 | + if "_" not in name: | 
|  | 165 | + try: | 
|  | 166 | + index = int(name) | 
|  | 167 | + except ValueError: | 
|  | 168 | + pass | 
|  | 169 | + else: | 
|  | 170 | + return IndexTransformFactory(index, self.base_field) | 
|  | 171 | + try: | 
|  | 172 | + start, end = name.split("_") | 
|  | 173 | + start = int(start) | 
|  | 174 | + end = int(end) | 
|  | 175 | + except ValueError: | 
|  | 176 | + pass | 
|  | 177 | + else: | 
|  | 178 | + return SliceTransformFactory(start, end) | 
|  | 179 | + | 
|  | 180 | + def validate(self, value, model_instance): | 
|  | 181 | + super().validate(value, model_instance) | 
|  | 182 | + for index, part in enumerate(value): | 
|  | 183 | + try: | 
|  | 184 | + self.base_field.validate(part, model_instance) | 
|  | 185 | + except exceptions.ValidationError as error: | 
|  | 186 | + raise prefix_validation_error( | 
|  | 187 | + error, | 
|  | 188 | + prefix=self.error_messages["item_invalid"], | 
|  | 189 | + code="item_invalid", | 
|  | 190 | + params={"nth": index + 1}, | 
|  | 191 | + ) from None | 
|  | 192 | + if isinstance(self.base_field, ArrayField) and len({len(i) for i in value}) > 1: | 
|  | 193 | + raise exceptions.ValidationError( | 
|  | 194 | + self.error_messages["nested_array_mismatch"], | 
|  | 195 | + code="nested_array_mismatch", | 
|  | 196 | + ) | 
|  | 197 | + | 
|  | 198 | + def run_validators(self, value): | 
|  | 199 | + super().run_validators(value) | 
|  | 200 | + for index, part in enumerate(value): | 
|  | 201 | + try: | 
|  | 202 | + self.base_field.run_validators(part) | 
|  | 203 | + except exceptions.ValidationError as error: | 
|  | 204 | + raise prefix_validation_error( | 
|  | 205 | + error, | 
|  | 206 | + prefix=self.error_messages["item_invalid"], | 
|  | 207 | + code="item_invalid", | 
|  | 208 | + params={"nth": index + 1}, | 
|  | 209 | + ) from None | 
|  | 210 | + | 
|  | 211 | + def formfield(self, **kwargs): | 
|  | 212 | + return super().formfield( | 
|  | 213 | + **{ | 
|  | 214 | + "form_class": SimpleArrayField, | 
|  | 215 | + "base_field": self.base_field.formfield(), | 
|  | 216 | + "max_length": self.size, | 
|  | 217 | + **kwargs, | 
|  | 218 | + } | 
|  | 219 | + ) | 
|  | 220 | + | 
|  | 221 | + | 
|  | 222 | +class Array(Func): | 
|  | 223 | + def as_mql(self, compiler, connection): | 
|  | 224 | + return [expr.as_mql(compiler, connection) for expr in self.get_source_expressions()] | 
|  | 225 | + | 
|  | 226 | + | 
|  | 227 | +class ArrayRHSMixin: | 
|  | 228 | + def __init__(self, lhs, rhs): | 
|  | 229 | + if isinstance(rhs, tuple | list): | 
|  | 230 | + expressions = [] | 
|  | 231 | + for value in rhs: | 
|  | 232 | + if not hasattr(value, "resolve_expression"): | 
|  | 233 | + field = lhs.output_field | 
|  | 234 | + value = Value(field.base_field.get_prep_value(value)) | 
|  | 235 | + expressions.append(value) | 
|  | 236 | + rhs = Array(*expressions) | 
|  | 237 | + super().__init__(lhs, rhs) | 
|  | 238 | + | 
|  | 239 | + | 
|  | 240 | +@ArrayField.register_lookup | 
|  | 241 | +class ArrayContains(ArrayRHSMixin, FieldGetDbPrepValueMixin, Lookup): | 
|  | 242 | + lookup_name = "contains" | 
|  | 243 | + | 
|  | 244 | + def as_mql(self, compiler, connection): | 
|  | 245 | + lhs_mql = process_lhs(self, compiler, connection) | 
|  | 246 | + value = process_rhs(self, compiler, connection) | 
|  | 247 | + return {"$and": [{"$ne": [lhs_mql, None]}, {"$setIsSubset": [value, lhs_mql]}]} | 
|  | 248 | + | 
|  | 249 | + | 
|  | 250 | +@ArrayField.register_lookup | 
|  | 251 | +class ArrayContainedBy(ArrayRHSMixin, FieldGetDbPrepValueMixin, Lookup): | 
|  | 252 | + lookup_name = "contained_by" | 
|  | 253 | + | 
|  | 254 | + def as_mql(self, compiler, connection): | 
|  | 255 | + lhs_mql = process_lhs(self, compiler, connection) | 
|  | 256 | + value = process_rhs(self, compiler, connection) | 
|  | 257 | + return { | 
|  | 258 | + "$and": [ | 
|  | 259 | + {"$ne": [lhs_mql, None]}, | 
|  | 260 | + {"$ne": [value, None]}, | 
|  | 261 | + {"$setIsSubset": [lhs_mql, value]}, | 
|  | 262 | + ] | 
|  | 263 | + } | 
|  | 264 | + | 
|  | 265 | + | 
|  | 266 | +@ArrayField.register_lookup | 
|  | 267 | +class ArrayExact(ArrayRHSMixin, Exact): | 
|  | 268 | + pass | 
|  | 269 | + | 
|  | 270 | + | 
|  | 271 | +@ArrayField.register_lookup | 
|  | 272 | +class ArrayOverlap(ArrayRHSMixin, FieldGetDbPrepValueMixin, Lookup): | 
|  | 273 | + lookup_name = "overlap" | 
|  | 274 | + | 
|  | 275 | + def as_mql(self, compiler, connection): | 
|  | 276 | + lhs_mql = process_lhs(self, compiler, connection) | 
|  | 277 | + value = process_rhs(self, compiler, connection) | 
|  | 278 | + return { | 
|  | 279 | + "$and": [{"$ne": [lhs_mql, None]}, {"$size": {"$setIntersection": [value, lhs_mql]}}] | 
|  | 280 | + } | 
|  | 281 | + | 
|  | 282 | + | 
|  | 283 | +@ArrayField.register_lookup | 
|  | 284 | +class ArrayLenTransform(Transform): | 
|  | 285 | + lookup_name = "len" | 
|  | 286 | + output_field = IntegerField() | 
|  | 287 | + | 
|  | 288 | + def as_mql(self, compiler, connection): | 
|  | 289 | + lhs_mql = process_lhs(self, compiler, connection) | 
|  | 290 | + return {"$cond": {"if": {"$eq": [lhs_mql, None]}, "then": None, "else": {"$size": lhs_mql}}} | 
|  | 291 | + | 
|  | 292 | + | 
|  | 293 | +@ArrayField.register_lookup | 
|  | 294 | +class ArrayInLookup(In): | 
|  | 295 | + def get_prep_lookup(self): | 
|  | 296 | + values = super().get_prep_lookup() | 
|  | 297 | + if hasattr(values, "resolve_expression"): | 
|  | 298 | + return values | 
|  | 299 | + # process_rhs() expects hashable values, so convert lists to tuples. | 
|  | 300 | + prepared_values = [] | 
|  | 301 | + for value in values: | 
|  | 302 | + if hasattr(value, "resolve_expression"): | 
|  | 303 | + prepared_values.append(value) | 
|  | 304 | + else: | 
|  | 305 | + prepared_values.append(tuple(value)) | 
|  | 306 | + return prepared_values | 
|  | 307 | + | 
|  | 308 | + | 
|  | 309 | +class IndexTransform(Transform): | 
|  | 310 | + def __init__(self, index, base_field, *args, **kwargs): | 
|  | 311 | + super().__init__(*args, **kwargs) | 
|  | 312 | + self.index = index | 
|  | 313 | + self.base_field = base_field | 
|  | 314 | + | 
|  | 315 | + def as_mql(self, compiler, connection): | 
|  | 316 | + lhs_mql = process_lhs(self, compiler, connection) | 
|  | 317 | + return {"$arrayElemAt": [lhs_mql, self.index]} | 
|  | 318 | + | 
|  | 319 | + @property | 
|  | 320 | + def output_field(self): | 
|  | 321 | + return self.base_field | 
|  | 322 | + | 
|  | 323 | + | 
|  | 324 | +class IndexTransformFactory: | 
|  | 325 | + def __init__(self, index, base_field): | 
|  | 326 | + self.index = index | 
|  | 327 | + self.base_field = base_field | 
|  | 328 | + | 
|  | 329 | + def __call__(self, *args, **kwargs): | 
|  | 330 | + return IndexTransform(self.index, self.base_field, *args, **kwargs) | 
|  | 331 | + | 
|  | 332 | + | 
|  | 333 | +class SliceTransform(Transform): | 
|  | 334 | + def __init__(self, start, end, *args, **kwargs): | 
|  | 335 | + super().__init__(*args, **kwargs) | 
|  | 336 | + self.start = start | 
|  | 337 | + self.end = end | 
|  | 338 | + | 
|  | 339 | + def as_mql(self, compiler, connection): | 
|  | 340 | + lhs_mql = process_lhs(self, compiler, connection) | 
|  | 341 | + return {"$slice": [lhs_mql, self.start, self.end]} | 
|  | 342 | + | 
|  | 343 | + | 
|  | 344 | +class SliceTransformFactory: | 
|  | 345 | + def __init__(self, start, end): | 
|  | 346 | + self.start = start | 
|  | 347 | + self.end = end | 
|  | 348 | + | 
|  | 349 | + def __call__(self, *args, **kwargs): | 
|  | 350 | + return SliceTransform(self.start, self.end, *args, **kwargs) | 
0 commit comments