1- use pgt_console:: {
2- fmt:: { Formatter , HTML } ,
3- markup,
4- } ;
5- use pgt_diagnostics:: PrintDiagnostic ;
6- use pgt_typecheck:: { TypecheckParams , check_sql} ;
1+ use pgt_console:: fmt:: { Formatter , HTML } ;
2+ use pgt_diagnostics:: Diagnostic ;
3+ use pgt_typecheck:: { IdentifierType , TypecheckParams , TypedIdentifier , check_sql} ;
74use sqlx:: { Executor , PgPool } ;
5+ use std:: fmt:: Write ;
86
9- async fn test ( name : & str , query : & str , setup : Option < & str > , test_db : & PgPool ) {
10- if let Some ( setup ) = setup {
11- test_db
12- . execute ( setup )
13- . await
14- . expect ( "Failed to setup test database" ) ;
15- }
7+ struct TestSetup < ' a > {
8+ name : & ' a str ,
9+ query : & ' a str ,
10+ setup : Option < & ' a str > ,
11+ test_db : & ' a PgPool ,
12+ typed_identifiers : Vec < TypedIdentifier > ,
13+ }
1614
17- let mut parser = tree_sitter:: Parser :: new ( ) ;
18- parser
19- . set_language ( & pgt_treesitter_grammar:: LANGUAGE . into ( ) )
20- . expect ( "Error loading sql language" ) ;
21-
22- let schema_cache = pgt_schema_cache:: SchemaCache :: load ( test_db)
23- . await
24- . expect ( "Failed to load Schema Cache" ) ;
25-
26- let root = pgt_query:: parse ( query)
27- . unwrap ( )
28- . into_root ( )
29- . expect ( "Failed to parse query" ) ;
30- let tree = parser. parse ( query, None ) . unwrap ( ) ;
31-
32- let conn = & test_db;
33- let result = check_sql ( TypecheckParams {
34- conn,
35- sql : query,
36- ast : & root,
37- tree : & tree,
38- schema_cache : & schema_cache,
39- search_path_patterns : vec ! [ ] ,
40- identifiers : vec ! [ ] ,
41- } )
42- . await ;
15+ impl < ' a > TestSetup < ' a > {
16+ async fn test ( self ) {
17+ if let Some ( setup) = self . setup {
18+ self . test_db
19+ . execute ( setup)
20+ . await
21+ . expect ( "Failed to setup test selfbase" ) ;
22+ }
4323
44- let mut content = vec ! [ ] ;
45- let mut writer = HTML :: new ( & mut content) ;
24+ let mut parser = tree_sitter:: Parser :: new ( ) ;
25+ parser
26+ . set_language ( & pgt_treesitter_grammar:: LANGUAGE . into ( ) )
27+ . expect ( "Error loading sql language" ) ;
4628
47- Formatter :: new ( & mut writer)
48- . write_markup ( markup ! {
49- { PrintDiagnostic :: simple( & result. unwrap( ) . unwrap( ) ) }
29+ let schema_cache = pgt_schema_cache:: SchemaCache :: load ( & self . test_db )
30+ . await
31+ . expect ( "Failed to load Schema Cache" ) ;
32+
33+ let root = pgt_query:: parse ( self . query )
34+ . unwrap ( )
35+ . into_root ( )
36+ . expect ( "Failed to parse query" ) ;
37+ let tree = parser. parse ( self . query , None ) . unwrap ( ) ;
38+
39+ let result = check_sql ( TypecheckParams {
40+ conn : self . test_db ,
41+ sql : self . query ,
42+ ast : & root,
43+ tree : & tree,
44+ schema_cache : & schema_cache,
45+ identifiers : self . typed_identifiers ,
46+ search_path_patterns : vec ! [ ] ,
5047 } )
51- . unwrap ( ) ;
48+ . await ;
49+
50+ assert ! (
51+ result. is_ok( ) ,
52+ "Got Typechecking error: {}" ,
53+ result. unwrap_err( )
54+ ) ;
5255
53- let content = String :: from_utf8 ( content ) . unwrap ( ) ;
56+ let maybe_diagnostic = result . unwrap ( ) ;
5457
55- insta:: with_settings!( {
56- prepend_module_to_snapshot => false ,
57- } , {
58- insta:: assert_snapshot!( name, content) ;
59- } ) ;
58+ let content = match maybe_diagnostic {
59+ Some ( d) => {
60+ let mut result = String :: new ( ) ;
61+
62+ if let Some ( span) = d. location ( ) . span {
63+ for ( idx, c) in self . query . char_indices ( ) {
64+ if pgt_text_size:: TextSize :: new ( idx. try_into ( ) . unwrap ( ) ) == span. start ( ) {
65+ result. push_str ( "~~~" ) ;
66+ }
67+ if pgt_text_size:: TextSize :: new ( idx. try_into ( ) . unwrap ( ) ) == span. end ( ) {
68+ result. push_str ( "~~~" ) ;
69+ }
70+ result. push ( c) ;
71+ }
72+ } else {
73+ result. push_str ( self . query ) ;
74+ }
75+
76+ writeln ! ( & mut result) . unwrap ( ) ;
77+ writeln ! ( & mut result) . unwrap ( ) ;
78+
79+ let mut msg_content = vec ! [ ] ;
80+ let mut writer = HTML :: new ( & mut msg_content) ;
81+ let mut formatter = Formatter :: new ( & mut writer) ;
82+ d. message ( & mut formatter) . unwrap ( ) ;
83+
84+ result. push_str ( String :: from_utf8 ( msg_content) . unwrap ( ) . as_str ( ) ) ;
85+
86+ result
87+ }
88+ None => String :: from ( "No Diagnostic" ) ,
89+ } ;
90+
91+ insta:: with_settings!( {
92+ prepend_module_to_snapshot => false ,
93+ } , {
94+ insta:: assert_snapshot!( self . name, content) ;
95+
96+ } ) ;
97+ }
6098}
6199
62100#[ sqlx:: test( migrator = "pgt_test_utils::MIGRATIONS" ) ]
63- async fn invalid_column ( pool : PgPool ) {
64- test (
65- "invalid_column" ,
66- "select id, unknown from contacts;" ,
67- Some (
101+ async fn invalid_column ( test_db : PgPool ) {
102+ TestSetup {
103+ name : "invalid_column" ,
104+ query : "select id, unknown from contacts;" ,
105+ setup : Some (
68106 r#"
69107 create table public.contacts (
70108 id serial primary key,
@@ -74,7 +112,66 @@ async fn invalid_column(pool: PgPool) {
74112 );
75113 "# ,
76114 ) ,
77- & pool,
78- )
115+ test_db : & test_db,
116+ typed_identifiers : vec ! [ ] ,
117+ }
118+ . test ( )
119+ . await ;
120+ }
121+
122+ #[ sqlx:: test( migrator = "pgt_test_utils::MIGRATIONS" ) ]
123+ async fn invalid_type_in_function ( test_db : PgPool ) {
124+ // create or replace function clean_up(uid uuid)
125+ // returns void
126+ // language sql
127+ // as $$
128+ // delete from public.contacts where id = uid;
129+ // $$;
130+
131+ let setup = r#"
132+ create table public.contacts (
133+ id serial primary key,
134+ name text not null,
135+ is_vegetarian bool default false,
136+ middle_name varchar(255)
137+ );
138+ "# ;
139+
140+ /* NOTE: The replaced type default value is *longer* than the param name. */
141+ TestSetup {
142+ name : "invalid_type_in_function_longer_default" ,
143+ setup : Some ( setup) ,
144+ query : r#"delete from public.contacts where id = uid;"# ,
145+ test_db : & test_db,
146+ typed_identifiers : vec ! [ TypedIdentifier {
147+ path: "clean_up" . to_string( ) ,
148+ name: Some ( "uid" . to_string( ) ) ,
149+ type_: IdentifierType {
150+ schema: None ,
151+ name: "uuid" . to_string( ) ,
152+ is_array: false ,
153+ } ,
154+ } ] ,
155+ }
156+ . test ( )
157+ . await ;
158+
159+ /* NOTE: The replaced type default value is *longer* than the param name. */
160+ TestSetup {
161+ name : "invalid_type_in_function_shorter_default" ,
162+ setup : None ,
163+ query : r#"delete from public.contacts where id = contact_name;"# ,
164+ test_db : & test_db,
165+ typed_identifiers : vec ! [ TypedIdentifier {
166+ path: "clean_up" . to_string( ) ,
167+ name: Some ( "contact_name" . to_string( ) ) ,
168+ type_: IdentifierType {
169+ schema: None ,
170+ name: "text" . to_string( ) ,
171+ is_array: false ,
172+ } ,
173+ } ] ,
174+ }
175+ . test ( )
79176 . await ;
80177}
0 commit comments