1- use  pgt_console:: { 
2-  fmt:: { Formatter ,  HTML } , 
3-  markup, 
4- } ; 
5- use  pgt_diagnostics:: PrintDiagnostic ; 
6- use  pgt_typecheck:: { TypecheckParams ,  check_sql} ; 
1+ use  pgt_console:: fmt:: { Formatter ,  HTML } ; 
2+ use  pgt_diagnostics:: Diagnostic ; 
3+ use  pgt_typecheck:: { IdentifierType ,  TypecheckParams ,  TypedIdentifier ,  check_sql} ; 
74use  sqlx:: { Executor ,  PgPool } ; 
5+ use  std:: fmt:: Write ; 
86
9- async   fn   test ( name :   & str ,   query :   & str ,   setup :   Option < & str > ,   test_db :   & PgPool )  { 
10-  if   let   Some ( setup )  = setup  { 
11-   test_db 
12-    . execute ( setup ) 
13-    . await 
14-    . expect ( "Failed to setup test database" ) ; 
15-   } 
7+ struct   TestSetup < ' a >  { 
8+  name :   & ' a   str , 
9+  query :   & ' a   str , 
10+  setup :   Option < & ' a   str > , 
11+  test_db :   & ' a   PgPool , 
12+  typed_identifiers :   Vec < TypedIdentifier > , 
13+ } 
1614
17-  let  mut  parser = tree_sitter:: Parser :: new ( ) ; 
18-  parser
19-  . set_language ( & pgt_treesitter_grammar:: LANGUAGE . into ( ) ) 
20-  . expect ( "Error loading sql language" ) ; 
21- 
22-  let  schema_cache = pgt_schema_cache:: SchemaCache :: load ( test_db) 
23-  . await 
24-  . expect ( "Failed to load Schema Cache" ) ; 
25- 
26-  let  root = pgt_query:: parse ( query) 
27-  . unwrap ( ) 
28-  . into_root ( ) 
29-  . expect ( "Failed to parse query" ) ; 
30-  let  tree = parser. parse ( query,  None ) . unwrap ( ) ; 
31- 
32-  let  conn = & test_db; 
33-  let  result = check_sql ( TypecheckParams  { 
34-  conn, 
35-  sql :  query, 
36-  ast :  & root, 
37-  tree :  & tree, 
38-  schema_cache :  & schema_cache, 
39-  search_path_patterns :  vec ! [ ] , 
40-  identifiers :  vec ! [ ] , 
41-  } ) 
42-  . await ; 
15+ impl  TestSetup < ' _ >  { 
16+  async  fn  test ( self )  { 
17+  if  let  Some ( setup)  = self . setup  { 
18+  self . test_db 
19+  . execute ( setup) 
20+  . await 
21+  . expect ( "Failed to setup test selfbase" ) ; 
22+  } 
4323
44-  let  mut  content = vec ! [ ] ; 
45-  let  mut  writer = HTML :: new ( & mut  content) ; 
24+  let  mut  parser = tree_sitter:: Parser :: new ( ) ; 
25+  parser
26+  . set_language ( & pgt_treesitter_grammar:: LANGUAGE . into ( ) ) 
27+  . expect ( "Error loading sql language" ) ; 
4628
47-  Formatter :: new ( & mut  writer) 
48-  . write_markup ( markup !  { 
49-  { PrintDiagnostic :: simple( & result. unwrap( ) . unwrap( ) ) } 
29+  let  schema_cache = pgt_schema_cache:: SchemaCache :: load ( self . test_db ) 
30+  . await 
31+  . expect ( "Failed to load Schema Cache" ) ; 
32+ 
33+  let  root = pgt_query:: parse ( self . query ) 
34+  . unwrap ( ) 
35+  . into_root ( ) 
36+  . expect ( "Failed to parse query" ) ; 
37+  let  tree = parser. parse ( self . query ,  None ) . unwrap ( ) ; 
38+ 
39+  let  result = check_sql ( TypecheckParams  { 
40+  conn :  self . test_db , 
41+  sql :  self . query , 
42+  ast :  & root, 
43+  tree :  & tree, 
44+  schema_cache :  & schema_cache, 
45+  identifiers :  self . typed_identifiers , 
46+  search_path_patterns :  vec ! [ ] , 
5047 } ) 
51-  . unwrap ( ) ; 
48+  . await ; 
49+ 
50+  assert ! ( 
51+  result. is_ok( ) , 
52+  "Got Typechecking error: {}" , 
53+  result. unwrap_err( ) 
54+  ) ; 
5255
53-  let  content  = String :: from_utf8 ( content ) . unwrap ( ) ; 
56+    let  maybe_diagnostic  = result . unwrap ( ) ; 
5457
55-  insta:: with_settings!( { 
56-  prepend_module_to_snapshot => false , 
57-  } ,  { 
58-  insta:: assert_snapshot!( name,  content) ; 
59-  } ) ; 
58+  let  content = match  maybe_diagnostic { 
59+  Some ( d)  => { 
60+  let  mut  result = String :: new ( ) ; 
61+ 
62+  if  let  Some ( span)  = d. location ( ) . span  { 
63+  for  ( idx,  c)  in  self . query . char_indices ( )  { 
64+  if  pgt_text_size:: TextSize :: new ( idx. try_into ( ) . unwrap ( ) )  == span. start ( )  { 
65+  result. push_str ( "~~~" ) ; 
66+  } 
67+  if  pgt_text_size:: TextSize :: new ( idx. try_into ( ) . unwrap ( ) )  == span. end ( )  { 
68+  result. push_str ( "~~~" ) ; 
69+  } 
70+  result. push ( c) ; 
71+  } 
72+  }  else  { 
73+  result. push_str ( "~~~" ) ; 
74+  result. push_str ( self . query ) ; 
75+  result. push_str ( "~~~" ) ; 
76+  } 
77+ 
78+  writeln ! ( & mut  result) . unwrap ( ) ; 
79+  writeln ! ( & mut  result) . unwrap ( ) ; 
80+ 
81+  let  mut  msg_content = vec ! [ ] ; 
82+  let  mut  writer = HTML :: new ( & mut  msg_content) ; 
83+  let  mut  formatter = Formatter :: new ( & mut  writer) ; 
84+  d. message ( & mut  formatter) . unwrap ( ) ; 
85+ 
86+  result. push_str ( String :: from_utf8 ( msg_content) . unwrap ( ) . as_str ( ) ) ; 
87+ 
88+  result
89+  } 
90+  None  => String :: from ( "No Diagnostic" ) , 
91+  } ; 
92+ 
93+  insta:: with_settings!( { 
94+  prepend_module_to_snapshot => false , 
95+  } ,  { 
96+  insta:: assert_snapshot!( self . name,  content) ; 
97+ 
98+  } ) ; 
99+  } 
60100} 
61101
62102#[ sqlx:: test( migrator = "pgt_test_utils::MIGRATIONS" ) ]  
63- async  fn  invalid_column ( pool :  PgPool )  { 
64-  test ( 
65-  "invalid_column" , 
66-  "select id, unknown from contacts;" , 
67-  Some ( 
103+ async  fn  invalid_column ( test_db :  PgPool )  { 
104+  TestSetup   { 
105+  name :   "invalid_column" , 
106+  query :   "select id, unknown from contacts;" , 
107+  setup :   Some ( 
68108 r#" 
69109 create table public.contacts ( 
70110 id serial primary key, 
@@ -74,7 +114,66 @@ async fn invalid_column(pool: PgPool) {
74114 ); 
75115 "# , 
76116 ) , 
77-  & pool, 
78-  ) 
117+  test_db :  & test_db, 
118+  typed_identifiers :  vec ! [ ] , 
119+  } 
120+  . test ( ) 
121+  . await ; 
122+ } 
123+ 
124+ #[ sqlx:: test( migrator = "pgt_test_utils::MIGRATIONS" ) ]  
125+ async  fn  invalid_type_in_function ( test_db :  PgPool )  { 
126+  // create or replace function clean_up(uid uuid) 
127+  // returns void 
128+  // language sql 
129+  // as $$ 
130+  // delete from public.contacts where id = uid; 
131+  // $$; 
132+ 
133+  let  setup = r#" 
134+  create table public.contacts ( 
135+  id serial primary key, 
136+  name text not null, 
137+  is_vegetarian bool default false, 
138+  middle_name varchar(255) 
139+  ); 
140+  "# ; 
141+ 
142+  /* NOTE: The replaced type default value is *longer* than the param name. */ 
143+  TestSetup  { 
144+  name :  "invalid_type_in_function_longer_default" , 
145+  setup :  Some ( setup) , 
146+  query :  r#"delete from public.contacts where id = uid;"# , 
147+  test_db :  & test_db, 
148+  typed_identifiers :  vec ! [ TypedIdentifier  { 
149+  path:  "clean_up" . to_string( ) , 
150+  name:  Some ( "uid" . to_string( ) ) , 
151+  type_:  IdentifierType  { 
152+  schema:  None , 
153+  name:  "uuid" . to_string( ) , 
154+  is_array:  false , 
155+  } , 
156+  } ] , 
157+  } 
158+  . test ( ) 
159+  . await ; 
160+ 
161+  /* NOTE: The replaced type default value is *shorter* than the param name. */ 
162+  TestSetup  { 
163+  name :  "invalid_type_in_function_shorter_default" , 
164+  setup :  None , 
165+  query :  r#"delete from public.contacts where id = contact_name;"# , 
166+  test_db :  & test_db, 
167+  typed_identifiers :  vec ! [ TypedIdentifier  { 
168+  path:  "clean_up" . to_string( ) , 
169+  name:  Some ( "contact_name" . to_string( ) ) , 
170+  type_:  IdentifierType  { 
171+  schema:  None , 
172+  name:  "text" . to_string( ) , 
173+  is_array:  false , 
174+  } , 
175+  } ] , 
176+  } 
177+  . test ( ) 
79178 . await ; 
80179} 
0 commit comments