1+ import difflib
2+
13from django .core import checks
4+ from django .core .exceptions import FieldDoesNotExist
25from django .db import models
36from django .db .models .fields .related import lazy_related_operation
47from django .db .models .lookups import Transform
@@ -123,7 +126,8 @@ def get_transform(self, name):
123126 transform = super ().get_transform (name )
124127 if transform :
125128 return transform
126- return KeyTransformFactory (name )
129+ field = self .embedded_model ._meta .get_field (name )
130+ return KeyTransformFactory (name , field )
127131
128132 def validate (self , value , model_instance ):
129133 super ().validate (value , model_instance )
@@ -145,9 +149,36 @@ def formfield(self, **kwargs):
145149
146150
147151class KeyTransform (Transform ):
148- def __init__ (self , key_name , * args , ** kwargs ):
152+ def __init__ (self , key_name , ref_field , * args , ** kwargs ):
149153 super ().__init__ (* args , ** kwargs )
150154 self .key_name = str (key_name )
155+ self .ref_field = ref_field
156+
157+ def get_transform (self , name ):
158+ """
159+ Validate that `name` is either a field of an embedded model or a
160+ lookup on an embedded model's field.
161+ """
162+ result = None
163+ if isinstance (self .ref_field , EmbeddedModelField ):
164+ opts = self .ref_field .embedded_model ._meta
165+ new_field = opts .get_field (name )
166+ result = KeyTransformFactory (name , new_field )
167+ else :
168+ if self .ref_field .get_transform (name ) is None :
169+ suggested_lookups = difflib .get_close_matches (name , self .ref_field .get_lookups ())
170+ if suggested_lookups :
171+ suggested_lookups = " or " .join (suggested_lookups )
172+ suggestion = f", perhaps you meant { suggested_lookups } ?"
173+ else :
174+ suggestion = "."
175+ raise FieldDoesNotExist (
176+ f"Unsupported lookup '{ name } ' for "
177+ f"{ self .ref_field .__class__ .__name__ } '{ self .ref_field .name } '"
178+ f"{ suggestion } "
179+ )
180+ result = KeyTransformFactory (name , self .ref_field )
181+ return result
151182
152183 def preprocess_lhs (self , compiler , connection ):
153184 key_transforms = [self .key_name ]
@@ -165,8 +196,9 @@ def as_mql(self, compiler, connection):
165196
166197
167198class KeyTransformFactory :
168- def __init__ (self , key_name ):
199+ def __init__ (self , key_name , ref_field ):
169200 self .key_name = key_name
201+ self .ref_field = ref_field
170202
171203 def __call__ (self , * args , ** kwargs ):
172- return KeyTransform (self .key_name , * args , ** kwargs )
204+ return KeyTransform (self .key_name , self . ref_field , * args , ** kwargs )
0 commit comments