1+ use  std:: collections:: { HashMap ,  HashSet } ; 
2+ 
13use  pg_schema_cache:: SchemaCache ; 
4+ use  pg_treesitter_queries:: { 
5+  queries:: { self ,  QueryResult } , 
6+  TreeSitterQueriesExecutor , 
7+ } ; 
28
39use  crate :: CompletionParams ; 
410
@@ -52,6 +58,9 @@ pub(crate) struct CompletionContext<'a> {
5258 pub  schema_name :  Option < String > , 
5359 pub  wrapping_clause_type :  Option < ClauseType > , 
5460 pub  is_invocation :  bool , 
61+  pub  wrapping_statement_range :  Option < tree_sitter:: Range > , 
62+ 
63+  pub  mentioned_relations :  HashMap < Option < String > ,  HashSet < String > > , 
5564} 
5665
5766impl < ' a >  CompletionContext < ' a >  { 
@@ -61,18 +70,56 @@ impl<'a> CompletionContext<'a> {
6170 text :  & params. text , 
6271 schema_cache :  params. schema , 
6372 position :  usize:: from ( params. position ) , 
64- 
6573 ts_node :  None , 
6674 schema_name :  None , 
6775 wrapping_clause_type :  None , 
76+  wrapping_statement_range :  None , 
6877 is_invocation :  false , 
78+  mentioned_relations :  HashMap :: new ( ) , 
6979 } ; 
7080
7181 ctx. gather_tree_context ( ) ; 
82+  ctx. gather_info_from_ts_queries ( ) ; 
7283
7384 ctx
7485 } 
7586
87+  fn  gather_info_from_ts_queries ( & mut  self )  { 
88+  let  tree = match  self . tree . as_ref ( )  { 
89+  None  => return , 
90+  Some ( t)  => t, 
91+  } ; 
92+ 
93+  let  stmt_range = self . wrapping_statement_range . as_ref ( ) ; 
94+  let  sql = self . text ; 
95+ 
96+  let  mut  executor = TreeSitterQueriesExecutor :: new ( tree. root_node ( ) ,  sql) ; 
97+ 
98+  executor. add_query_results :: < queries:: RelationMatch > ( ) ; 
99+ 
100+  for  relation_match in  executor. get_iter ( stmt_range)  { 
101+  match  relation_match { 
102+  QueryResult :: Relation ( r)  => { 
103+  let  schema_name = r. get_schema ( sql) ; 
104+  let  table_name = r. get_table ( sql) ; 
105+ 
106+  let  current = self . mentioned_relations . get_mut ( & schema_name) ; 
107+ 
108+  match  current { 
109+  Some ( c)  => { 
110+  c. insert ( table_name) ; 
111+  } 
112+  None  => { 
113+  let  mut  new = HashSet :: new ( ) ; 
114+  new. insert ( table_name) ; 
115+  self . mentioned_relations . insert ( schema_name,  new) ; 
116+  } 
117+  } ; 
118+  } 
119+  } ; 
120+  } 
121+  } 
122+ 
76123 pub  fn  get_ts_node_content ( & self ,  ts_node :  tree_sitter:: Node < ' a > )  -> Option < & ' a  str >  { 
77124 let  source = self . text ; 
78125 match  ts_node. utf8_text ( source. as_bytes ( ) )  { 
@@ -100,36 +147,38 @@ impl<'a> CompletionContext<'a> {
100147 * We'll therefore adjust the cursor position such that it meets the last node of the AST. 
101148 * `select * from use {}` becomes `select * from use{}`. 
102149 */ 
103-  let  current_node_kind  = cursor. node ( ) . kind ( ) ; 
150+  let  current_node  = cursor. node ( ) ; 
104151 while  cursor. goto_first_child_for_byte ( self . position ) . is_none ( )  && self . position  > 0  { 
105152 self . position  -= 1 ; 
106153 } 
107154
108-  self . gather_context_from_node ( cursor,  current_node_kind ) ; 
155+  self . gather_context_from_node ( cursor,  current_node ) ; 
109156 } 
110157
111158 fn  gather_context_from_node ( 
112159 & mut  self , 
113160 mut  cursor :  tree_sitter:: TreeCursor < ' a > , 
114-  previous_node_kind :   & str , 
161+  previous_node :  tree_sitter :: Node < ' a > , 
115162 )  { 
116163 let  current_node = cursor. node ( ) ; 
117-  let  current_node_kind = current_node. kind ( ) ; 
118164
119165 // prevent infinite recursion – this can happen if we only have a PROGRAM node 
120-  if  current_node_kind  == previous_node_kind  { 
166+  if  current_node . kind ( )  == previous_node . kind ( )  { 
121167 self . ts_node  = Some ( current_node) ; 
122168 return ; 
123169 } 
124170
125-  match  previous_node_kind { 
126-  "statement"  => self . wrapping_clause_type  = current_node_kind. try_into ( ) . ok ( ) , 
171+  match  previous_node. kind ( )  { 
172+  "statement"  | "subquery"  => { 
173+  self . wrapping_clause_type  = current_node. kind ( ) . try_into ( ) . ok ( ) ; 
174+  self . wrapping_statement_range  = Some ( previous_node. range ( ) ) ; 
175+  } 
127176 "invocation"  => self . is_invocation  = true , 
128177
129178 _ => { } 
130179 } 
131180
132-  match  current_node_kind  { 
181+  match  current_node . kind ( )  { 
133182 "object_reference"  => { 
134183 let  txt = self . get_ts_node_content ( current_node) ; 
135184 if  let  Some ( txt)  = txt { 
@@ -159,7 +208,7 @@ impl<'a> CompletionContext<'a> {
159208 } 
160209
161210 cursor. goto_first_child_for_byte ( self . position ) ; 
162-  self . gather_context_from_node ( cursor,  current_node_kind ) ; 
211+  self . gather_context_from_node ( cursor,  current_node ) ; 
163212 } 
164213} 
165214
@@ -209,7 +258,7 @@ mod tests {
209258 ] ; 
210259
211260 for  ( query,  expected_clause)  in  test_cases { 
212-  let  ( position,  text)  = get_text_and_position ( query. as_str ( ) ) ; 
261+  let  ( position,  text)  = get_text_and_position ( query. as_str ( ) . into ( ) ) ; 
213262
214263 let  tree = get_tree ( text. as_str ( ) ) ; 
215264
@@ -242,7 +291,7 @@ mod tests {
242291 ] ; 
243292
244293 for  ( query,  expected_schema)  in  test_cases { 
245-  let  ( position,  text)  = get_text_and_position ( query. as_str ( ) ) ; 
294+  let  ( position,  text)  = get_text_and_position ( query. as_str ( ) . into ( ) ) ; 
246295
247296 let  tree = get_tree ( text. as_str ( ) ) ; 
248297 let  params = crate :: CompletionParams  { 
@@ -276,7 +325,7 @@ mod tests {
276325 ] ; 
277326
278327 for  ( query,  is_invocation)  in  test_cases { 
279-  let  ( position,  text)  = get_text_and_position ( query. as_str ( ) ) ; 
328+  let  ( position,  text)  = get_text_and_position ( query. as_str ( ) . into ( ) ) ; 
280329
281330 let  tree = get_tree ( text. as_str ( ) ) ; 
282331 let  params = crate :: CompletionParams  { 
@@ -300,7 +349,7 @@ mod tests {
300349 ] ; 
301350
302351 for  query in  cases { 
303-  let  ( position,  text)  = get_text_and_position ( query. as_str ( ) ) ; 
352+  let  ( position,  text)  = get_text_and_position ( query. as_str ( ) . into ( ) ) ; 
304353
305354 let  tree = get_tree ( text. as_str ( ) ) ; 
306355
@@ -328,7 +377,7 @@ mod tests {
328377 fn  does_not_fail_on_trailing_whitespace ( )  { 
329378 let  query = format ! ( "select * from {}" ,  CURSOR_POS ) ; 
330379
331-  let  ( position,  text)  = get_text_and_position ( query. as_str ( ) ) ; 
380+  let  ( position,  text)  = get_text_and_position ( query. as_str ( ) . into ( ) ) ; 
332381
333382 let  tree = get_tree ( text. as_str ( ) ) ; 
334383
@@ -354,7 +403,7 @@ mod tests {
354403 fn  does_not_fail_with_empty_statements ( )  { 
355404 let  query = format ! ( "{}" ,  CURSOR_POS ) ; 
356405
357-  let  ( position,  text)  = get_text_and_position ( query. as_str ( ) ) ; 
406+  let  ( position,  text)  = get_text_and_position ( query. as_str ( ) . into ( ) ) ; 
358407
359408 let  tree = get_tree ( text. as_str ( ) ) ; 
360409
@@ -379,7 +428,7 @@ mod tests {
379428 // is selecting a certain column name, such as `frozen_account`. 
380429 let  query = format ! ( "select * fro{}" ,  CURSOR_POS ) ; 
381430
382-  let  ( position,  text)  = get_text_and_position ( query. as_str ( ) ) ; 
431+  let  ( position,  text)  = get_text_and_position ( query. as_str ( ) . into ( ) ) ; 
383432
384433 let  tree = get_tree ( text. as_str ( ) ) ; 
385434
0 commit comments