33
44from  django .core .exceptions  import  EmptyResultSet , FullResultSet 
55from  django .db  import  DatabaseError , IntegrityError , NotSupportedError 
6- from  django .db .models .expressions  import  Case , When 
6+ from  django .db .models .expressions  import  Case , Col ,  When 
77from  django .db .models .functions  import  Mod 
88from  django .db .models .lookups  import  Exact 
99from  django .db .models .sql .constants  import  INNER 
@@ -105,6 +105,7 @@ def join(self, compiler, connection):
105105 lhs_fields  =  []
106106 rhs_fields  =  []
107107 # Add a join condition for each pair of joining fields. 
108+  parent_template  =  "parent__field__" 
108109 for  lhs , rhs  in  self .join_fields :
109110 lhs , rhs  =  connection .ops .prepare_join_on_clause (
110111 self .parent_alias , lhs , compiler .collection_name , rhs 
@@ -113,8 +114,41 @@ def join(self, compiler, connection):
113114 # In the lookup stage, the reference to this column doesn't include 
114115 # the collection name. 
115116 rhs_fields .append (rhs .as_mql (compiler , connection ))
117+  # Handle any join conditions besides matching field pairs. 
118+  extra  =  self .join_field .get_extra_restriction (self .table_alias , self .parent_alias )
119+  if  extra :
120+  columns  =  []
121+  for  expr  in  extra .leaves ():
122+  # Determine whether the column needs to be transformed or rerouted 
123+  # as part of the subquery. 
124+  for  hand_side  in  ["lhs" , "rhs" ]:
125+  hand_side_value  =  getattr (expr , hand_side , None )
126+  if  isinstance (hand_side_value , Col ):
127+  # If the column is not part of the joined table, add it to 
128+  # lhs_fields. 
129+  if  hand_side_value .alias  !=  self .table_name :
130+  pos  =  len (lhs_fields )
131+  lhs_fields .append (expr .lhs .as_mql (compiler , connection ))
132+  else :
133+  pos  =  None 
134+  columns .append ((hand_side_value , pos ))
135+  # Replace columns in the extra conditions with new column references 
136+  # based on their rerouted positions in the join pipeline. 
137+  replacements  =  {}
138+  for  col , parent_pos  in  columns :
139+  column_target  =  Col (compiler .collection_name , expr .output_field .__class__ ())
140+  if  parent_pos  is  not None :
141+  target_col  =  f"${ parent_template } { parent_pos }  
142+  column_target .target .db_column  =  target_col 
143+  column_target .target .set_attributes_from_name (target_col )
144+  else :
145+  column_target .target  =  col .target 
146+  replacements [col ] =  column_target 
147+  # Apply the transformed expressions in the extra condition. 
148+  extra_condition  =  [extra .replace_expressions (replacements ).as_mql (compiler , connection )]
149+  else :
150+  extra_condition  =  []
116151
117-  parent_template  =  "parent__field__" 
118152 lookup_pipeline  =  [
119153 {
120154 "$lookup" : {
@@ -140,6 +174,7 @@ def join(self, compiler, connection):
140174 {"$eq" : [f"$${ parent_template } { i }  , field ]}
141175 for  i , field  in  enumerate (rhs_fields )
142176 ]
177+  +  extra_condition 
143178 }
144179 }
145180 }
0 commit comments