1
1
use proc_macro2:: { Ident , Span , TokenStream } ;
2
- use quote:: { quote, ToTokens } ;
2
+ use quote:: { quote, ToTokens , TokenStreamExt } ;
3
3
use syn:: Type ;
4
4
5
5
use sqlx_core:: column:: Column ;
@@ -14,7 +14,32 @@ use syn::Token;
14
14
15
15
pub struct RustColumn {
16
16
pub ( super ) ident : Ident ,
17
- pub ( super ) type_ : Option < TokenStream > ,
17
+ pub ( super ) type_ : ColumnType ,
18
+ }
19
+
20
+ pub ( super ) enum ColumnType {
21
+ Exact ( TokenStream ) ,
22
+ Wildcard ,
23
+ OptWildcard ,
24
+ }
25
+
26
+ impl ColumnType {
27
+ pub ( super ) fn is_wildcard ( & self ) -> bool {
28
+ match self {
29
+ ColumnType :: Exact ( _) => false ,
30
+ _ => true ,
31
+ }
32
+ }
33
+ }
34
+
35
+ impl ToTokens for ColumnType {
36
+ fn to_tokens ( & self , tokens : & mut TokenStream ) {
37
+ tokens. append_all ( match & self {
38
+ ColumnType :: Exact ( type_) => type_. clone ( ) . into_iter ( ) ,
39
+ ColumnType :: Wildcard => quote ! { _ } . into_iter ( ) ,
40
+ ColumnType :: OptWildcard => quote ! { Option <_> } . into_iter ( ) ,
41
+ } )
42
+ }
18
43
}
19
44
20
45
struct DisplayColumn < ' a > {
@@ -25,15 +50,25 @@ struct DisplayColumn<'a> {
25
50
26
51
struct ColumnDecl {
27
52
ident : Ident ,
28
- // TIL Rust still has OOP keywords like `abstract`, `final`, `override` and `virtual` reserved
29
- r#override : Option < ColumnOverride > ,
53
+ r#override : ColumnOverride ,
30
54
}
31
55
32
- enum ColumnOverride {
56
+ struct ColumnOverride {
57
+ nullability : ColumnNullabilityOverride ,
58
+ type_ : ColumnTypeOverride ,
59
+ }
60
+
61
+ #[ derive( PartialEq ) ]
62
+ enum ColumnNullabilityOverride {
33
63
NonNull ,
34
64
Nullable ,
35
- Wildcard ,
65
+ None ,
66
+ }
67
+
68
+ enum ColumnTypeOverride {
36
69
Exact ( Type ) ,
70
+ Wildcard ,
71
+ None ,
37
72
}
38
73
39
74
impl Display for DisplayColumn < ' _ > {
@@ -52,22 +87,30 @@ pub fn columns_to_rust<DB: DatabaseExt>(describe: &Describe<DB>) -> crate::Resul
52
87
let decl = ColumnDecl :: parse ( & column. name ( ) )
53
88
. map_err ( |e| format ! ( "column name {:?} is invalid: {}" , column. name( ) , e) ) ?;
54
89
55
- let type_ = match decl. r#override {
56
- Some ( ColumnOverride :: Exact ( ty) ) => Some ( ty. to_token_stream ( ) ) ,
57
- Some ( ColumnOverride :: Wildcard ) => None ,
58
- // these three could be combined but I prefer the clarity here
59
- Some ( ColumnOverride :: NonNull ) => Some ( get_column_type :: < DB > ( i, column) ) ,
60
- Some ( ColumnOverride :: Nullable ) => {
61
- let type_ = get_column_type :: < DB > ( i, column) ;
62
- Some ( quote ! { Option <#type_> } )
90
+ let ColumnOverride { nullability, type_ } = decl. r#override ;
91
+
92
+ let nullable = match nullability {
93
+ ColumnNullabilityOverride :: NonNull => false ,
94
+ ColumnNullabilityOverride :: Nullable => true ,
95
+ ColumnNullabilityOverride :: None => describe. nullable ( i) . unwrap_or ( true ) ,
96
+ } ;
97
+ let type_ = match ( type_, nullable) {
98
+ ( ColumnTypeOverride :: Exact ( type_) , false ) => {
99
+ ColumnType :: Exact ( type_. to_token_stream ( ) )
63
100
}
64
- None => {
65
- let type_ = get_column_type :: < DB > ( i, column) ;
101
+ ( ColumnTypeOverride :: Exact ( type_) , true ) => {
102
+ ColumnType :: Exact ( quote ! { Option <#type_> } )
103
+ }
104
+
105
+ ( ColumnTypeOverride :: Wildcard , false ) => ColumnType :: Wildcard ,
106
+ ( ColumnTypeOverride :: Wildcard , true ) => ColumnType :: OptWildcard ,
66
107
67
- if !describe. nullable ( i) . unwrap_or ( true ) {
68
- Some ( type_)
108
+ ( ColumnTypeOverride :: None , _) => {
109
+ let type_ = get_column_type :: < DB > ( i, column) ;
110
+ if !nullable {
111
+ ColumnType :: Exact ( type_)
69
112
} else {
70
- Some ( quote ! { Option <#type_> } )
113
+ ColumnType :: Exact ( quote ! { Option <#type_> } )
71
114
}
72
115
}
73
116
} ;
@@ -97,14 +140,17 @@ pub fn quote_query_as<DB: DatabaseExt>(
97
140
) | {
98
141
match ( input. checked , type_) {
99
142
// we guarantee the type is valid so we can skip the runtime check
100
- ( true , Some ( type_) ) => quote ! {
143
+ ( true , ColumnType :: Exact ( type_) ) => quote ! {
101
144
// binding to a `let` avoids confusing errors about
102
145
// "try expression alternatives have incompatible types"
103
146
// it doesn't seem to hurt inference in the other branches
104
147
let #ident = row. try_get_unchecked:: <#type_, _>( #i) ?;
105
148
} ,
106
149
// type was overridden to be a wildcard so we fallback to the runtime check
107
- ( true , None ) => quote ! ( let #ident = row. try_get( #i) ?; ) ,
150
+ ( true , ColumnType :: Wildcard ) => quote ! ( let #ident = row. try_get( #i) ?; ) ,
151
+ ( true , ColumnType :: OptWildcard ) => {
152
+ quote ! ( let #ident = row. try_get:: <Option <_>, _>( #i) ?; )
153
+ }
108
154
// macro is the `_unchecked!()` variant so this will die in decoding if it's wrong
109
155
( false , _) => quote ! ( let #ident = row. try_get_unchecked( #i) ?; ) ,
110
156
}
@@ -176,9 +222,12 @@ impl ColumnDecl {
176
222
Ok ( ColumnDecl {
177
223
ident,
178
224
r#override : if !remainder. is_empty ( ) {
179
- Some ( syn:: parse_str ( remainder) ?)
225
+ syn:: parse_str ( remainder) ?
180
226
} else {
181
- None
227
+ ColumnOverride {
228
+ nullability : ColumnNullabilityOverride :: None ,
229
+ type_ : ColumnTypeOverride :: None ,
230
+ }
182
231
} ,
183
232
} )
184
233
}
@@ -188,27 +237,33 @@ impl Parse for ColumnOverride {
188
237
fn parse ( input : ParseStream ) -> syn:: Result < Self > {
189
238
let lookahead = input. lookahead1 ( ) ;
190
239
191
- if lookahead. peek ( Token ! [ : ] ) {
240
+ let nullability = if lookahead. peek ( Token ! [ !] ) {
241
+ input. parse :: < Token ! [ !] > ( ) ?;
242
+
243
+ ColumnNullabilityOverride :: NonNull
244
+ } else if lookahead. peek ( Token ! [ ?] ) {
245
+ input. parse :: < Token ! [ ?] > ( ) ?;
246
+
247
+ ColumnNullabilityOverride :: Nullable
248
+ } else {
249
+ ColumnNullabilityOverride :: None
250
+ } ;
251
+
252
+ let type_ = if input. lookahead1 ( ) . peek ( Token ! [ : ] ) {
192
253
input. parse :: < Token ! [ : ] > ( ) ?;
193
254
194
255
let ty = Type :: parse ( input) ?;
195
256
196
257
if let Type :: Infer ( _) = ty {
197
- Ok ( ColumnOverride :: Wildcard )
258
+ ColumnTypeOverride :: Wildcard
198
259
} else {
199
- Ok ( ColumnOverride :: Exact ( ty) )
260
+ ColumnTypeOverride :: Exact ( ty)
200
261
}
201
- } else if lookahead. peek ( Token ! [ !] ) {
202
- input. parse :: < Token ! [ !] > ( ) ?;
203
-
204
- Ok ( ColumnOverride :: NonNull )
205
- } else if lookahead. peek ( Token ! [ ?] ) {
206
- input. parse :: < Token ! [ ?] > ( ) ?;
207
-
208
- Ok ( ColumnOverride :: Nullable )
209
262
} else {
210
- Err ( lookahead. error ( ) )
211
- }
263
+ ColumnTypeOverride :: None
264
+ } ;
265
+
266
+ Ok ( Self { nullability, type_ } )
212
267
}
213
268
}
214
269
0 commit comments