| 
 | 1 | +import itertools  | 
1 | 2 | from collections import defaultdict  | 
2 | 3 | 
 
  | 
 | 4 | +from django.core.checks import Error, Warning  | 
3 | 5 | from django.db import NotSupportedError  | 
4 |  | -from django.db.models import Index  | 
 | 6 | +from django.db.models import FloatField, Index, IntegerField  | 
5 | 7 | from django.db.models.lookups import BuiltinLookup  | 
6 | 8 | from django.db.models.sql.query import Query  | 
7 | 9 | from django.db.models.sql.where import AND, XOR, WhereNode  | 
8 | 10 | from pymongo import ASCENDING, DESCENDING  | 
9 |  | -from pymongo.operations import IndexModel  | 
 | 11 | +from pymongo.operations import IndexModel, SearchIndexModel  | 
 | 12 | + | 
 | 13 | +from django_mongodb_backend.fields import ArrayField  | 
10 | 14 | 
 
  | 
11 | 15 | from .query_utils import process_rhs  | 
12 | 16 | 
 
  | 
@@ -101,6 +105,181 @@ def where_node_idx(self, compiler, connection):  | 
101 | 105 |  return mql  | 
102 | 106 | 
 
  | 
103 | 107 | 
 
  | 
 | 108 | +class SearchIndex(Index):  | 
 | 109 | + suffix = "six"  | 
 | 110 | + _error_id_prefix = "django_mongodb_backend.indexes.SearchIndex"  | 
 | 111 | + | 
 | 112 | + def __init__(self, *, fields=(), name=None):  | 
 | 113 | + super().__init__(fields=fields, name=name)  | 
 | 114 | + | 
 | 115 | + def check(self, model, connection):  | 
 | 116 | + errors = []  | 
 | 117 | + if not connection.features.supports_atlas_search:  | 
 | 118 | + errors.append(  | 
 | 119 | + Warning(  | 
 | 120 | + f"This MongoDB server does not support {self.__class__.__name__}.",  | 
 | 121 | + hint=(  | 
 | 122 | + "The index won't be created. Use an Atlas-enabled version of MongoDB, "  | 
 | 123 | + "or silence this warning if you don't care about it."  | 
 | 124 | + ),  | 
 | 125 | + obj=model,  | 
 | 126 | + id=f"{self._error_id_prefix}.W001",  | 
 | 127 | + )  | 
 | 128 | + )  | 
 | 129 | + return errors  | 
 | 130 | + | 
 | 131 | + def search_index_data_types(self, db_type):  | 
 | 132 | + """  | 
 | 133 | + Map a model field's type to search index type.  | 
 | 134 | + https://www.mongodb.com/docs/atlas/atlas-search/define-field-mappings/#data-types  | 
 | 135 | + """  | 
 | 136 | + if db_type in {"double", "int", "long"}:  | 
 | 137 | + return "number"  | 
 | 138 | + if db_type == "binData":  | 
 | 139 | + return "string"  | 
 | 140 | + if db_type == "bool":  | 
 | 141 | + return "boolean"  | 
 | 142 | + if db_type == "object":  | 
 | 143 | + return "document"  | 
 | 144 | + if db_type == "array":  | 
 | 145 | + return "embeddedDocuments"  | 
 | 146 | + return db_type  | 
 | 147 | + | 
 | 148 | + def get_pymongo_index_model(  | 
 | 149 | + self, model, schema_editor, field=None, unique=False, column_prefix=""  | 
 | 150 | + ):  | 
 | 151 | + if not schema_editor.connection.features.supports_atlas_search:  | 
 | 152 | + return None  | 
 | 153 | + fields = {}  | 
 | 154 | + for field_name, _ in self.fields_orders:  | 
 | 155 | + field = model._meta.get_field(field_name)  | 
 | 156 | + type_ = self.search_index_data_types(field.db_type(schema_editor.connection))  | 
 | 157 | + field_path = column_prefix + model._meta.get_field(field_name).column  | 
 | 158 | + fields[field_path] = {"type": type_}  | 
 | 159 | + return SearchIndexModel(  | 
 | 160 | + definition={"mappings": {"dynamic": False, "fields": fields}}, name=self.name  | 
 | 161 | + )  | 
 | 162 | + | 
 | 163 | + | 
 | 164 | +class VectorSearchIndex(SearchIndex):  | 
 | 165 | + suffix = "vsi"  | 
 | 166 | + _error_id_prefix = "django_mongodb_backend.indexes.VectorSearchIndex"  | 
 | 167 | + VALID_FIELD_TYPES = frozenset(("boolean", "date", "number", "objectId", "string", "uuid"))  | 
 | 168 | + VALID_SIMILARITIES = frozenset(("cosine", "dotProduct", "euclidean"))  | 
 | 169 | + | 
 | 170 | + def __init__(self, *, fields=(), name=None, similarities):  | 
 | 171 | + super().__init__(fields=fields, name=name)  | 
 | 172 | + self.similarities = similarities  | 
 | 173 | + self._multiple_similarities = isinstance(similarities, tuple | list)  | 
 | 174 | + for func in similarities if self._multiple_similarities else (similarities,):  | 
 | 175 | + if func not in self.VALID_SIMILARITIES:  | 
 | 176 | + raise ValueError(  | 
 | 177 | + f"'{func}' isn't a valid similarity function "  | 
 | 178 | + f"({', '.join(sorted(self.VALID_SIMILARITIES))})."  | 
 | 179 | + )  | 
 | 180 | + seen_fields = set()  | 
 | 181 | + for field_name, _ in self.fields_orders:  | 
 | 182 | + if field_name in seen_fields:  | 
 | 183 | + raise ValueError(f"Field '{field_name}' is duplicated in fields.")  | 
 | 184 | + seen_fields.add(field_name)  | 
 | 185 | + | 
 | 186 | + def check(self, model, connection):  | 
 | 187 | + errors = super().check(model, connection)  | 
 | 188 | + num_arrayfields = 0  | 
 | 189 | + for field_name, _ in self.fields_orders:  | 
 | 190 | + field = model._meta.get_field(field_name)  | 
 | 191 | + if isinstance(field, ArrayField):  | 
 | 192 | + num_arrayfields += 1  | 
 | 193 | + try:  | 
 | 194 | + int(field.size)  | 
 | 195 | + except (ValueError, TypeError):  | 
 | 196 | + errors.append(  | 
 | 197 | + Error(  | 
 | 198 | + f"VectorSearchIndex requires 'size' on field '{field_name}'.",  | 
 | 199 | + obj=model,  | 
 | 200 | + id=f"{self._error_id_prefix}.E002",  | 
 | 201 | + )  | 
 | 202 | + )  | 
 | 203 | + if not isinstance(field.base_field, FloatField | IntegerField):  | 
 | 204 | + errors.append(  | 
 | 205 | + Error(  | 
 | 206 | + "VectorSearchIndex requires the base field of "  | 
 | 207 | + f"ArrayField '{field.name}' to be FloatField or "  | 
 | 208 | + "IntegerField but is "  | 
 | 209 | + f"{field.base_field.get_internal_type()}.",  | 
 | 210 | + obj=model,  | 
 | 211 | + id=f"{self._error_id_prefix}.E003",  | 
 | 212 | + )  | 
 | 213 | + )  | 
 | 214 | + else:  | 
 | 215 | + search_type = self.search_index_data_types(field.db_type(connection))  | 
 | 216 | + if search_type not in self.VALID_FIELD_TYPES:  | 
 | 217 | + errors.append(  | 
 | 218 | + Error(  | 
 | 219 | + "VectorSearchIndex does not support field "  | 
 | 220 | + f"'{field_name}' ({field.get_internal_type()}).",  | 
 | 221 | + obj=model,  | 
 | 222 | + id=f"{self._error_id_prefix}.E004",  | 
 | 223 | + hint=f"Allowed types are {', '.join(sorted(self.VALID_FIELD_TYPES))}.",  | 
 | 224 | + )  | 
 | 225 | + )  | 
 | 226 | + if self._multiple_similarities and num_arrayfields != len(self.similarities):  | 
 | 227 | + errors.append(  | 
 | 228 | + Error(  | 
 | 229 | + f"VectorSearchIndex requires the same number of similarities "  | 
 | 230 | + f"and vector fields; {model._meta.object_name} has "  | 
 | 231 | + f"{num_arrayfields} ArrayField(s) but similarities "  | 
 | 232 | + f"has {len(self.similarities)} element(s).",  | 
 | 233 | + obj=model,  | 
 | 234 | + id=f"{self._error_id_prefix}.E005",  | 
 | 235 | + )  | 
 | 236 | + )  | 
 | 237 | + if num_arrayfields == 0:  | 
 | 238 | + errors.append(  | 
 | 239 | + Error(  | 
 | 240 | + "VectorSearchIndex requires at least one ArrayField to " "store vector data.",  | 
 | 241 | + obj=model,  | 
 | 242 | + id=f"{self._error_id_prefix}.E006",  | 
 | 243 | + hint="If you want to perform search operations without vectors, "  | 
 | 244 | + "use SearchIndex instead.",  | 
 | 245 | + )  | 
 | 246 | + )  | 
 | 247 | + return errors  | 
 | 248 | + | 
 | 249 | + def deconstruct(self):  | 
 | 250 | + path, args, kwargs = super().deconstruct()  | 
 | 251 | + kwargs["similarities"] = self.similarities  | 
 | 252 | + return path, args, kwargs  | 
 | 253 | + | 
 | 254 | + def get_pymongo_index_model(  | 
 | 255 | + self, model, schema_editor, field=None, unique=False, column_prefix=""  | 
 | 256 | + ):  | 
 | 257 | + if not schema_editor.connection.features.supports_atlas_search:  | 
 | 258 | + return None  | 
 | 259 | + similarities = (  | 
 | 260 | + itertools.cycle([self.similarities])  | 
 | 261 | + if not self._multiple_similarities  | 
 | 262 | + else iter(self.similarities)  | 
 | 263 | + )  | 
 | 264 | + fields = []  | 
 | 265 | + for field_name, _ in self.fields_orders:  | 
 | 266 | + field_ = model._meta.get_field(field_name)  | 
 | 267 | + field_path = column_prefix + model._meta.get_field(field_name).column  | 
 | 268 | + mappings = {"path": field_path}  | 
 | 269 | + if isinstance(field_, ArrayField):  | 
 | 270 | + mappings.update(  | 
 | 271 | + {  | 
 | 272 | + "type": "vector",  | 
 | 273 | + "numDimensions": int(field_.size),  | 
 | 274 | + "similarity": next(similarities),  | 
 | 275 | + }  | 
 | 276 | + )  | 
 | 277 | + else:  | 
 | 278 | + mappings["type"] = "filter"  | 
 | 279 | + fields.append(mappings)  | 
 | 280 | + return SearchIndexModel(definition={"fields": fields}, name=self.name, type="vectorSearch")  | 
 | 281 | + | 
 | 282 | + | 
104 | 283 | def register_indexes():  | 
105 | 284 |  BuiltinLookup.as_mql_idx = builtin_lookup_idx  | 
106 | 285 |  Index._get_condition_mql = _get_condition_mql  | 
 | 
0 commit comments