Skip to content

Commit ced09e0

Browse files
authored
Support using both nullability and type overrides (launchbadge#549)
* Make it possible to use both nullability and type overrides * Fix override parsing lookahead logic * Update column override tests * Support nullability overrides with wildcard type overrides * Fix tests * Update query! overrides docs * Remove last bits of macro_result! * rustfmt
1 parent 116fbc1 commit ced09e0

File tree

8 files changed

+401
-95
lines changed

8 files changed

+401
-95
lines changed

sqlx-macros/src/migrate.rs

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -83,14 +83,10 @@ pub(crate) fn expand_migrator_from_dir(dir: LitStr) -> crate::Result<proc_macro2
8383
migrations.sort_by_key(|m| m.version);
8484

8585
Ok(quote! {
86-
macro_rules! macro_result {
87-
() => {
88-
sqlx::migrate::Migrator {
89-
migrations: std::borrow::Cow::Borrowed(&[
90-
#(#migrations),*
91-
])
92-
}
93-
}
86+
sqlx::migrate::Migrator {
87+
migrations: std::borrow::Cow::Borrowed(&[
88+
#(#migrations),*
89+
])
9490
}
9591
})
9692
}

sqlx-macros/src/query/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ where
223223
let record_name: Type = syn::parse_str("Record").unwrap();
224224

225225
for rust_col in &columns {
226-
if rust_col.type_.is_none() {
226+
if rust_col.type_.is_wildcard() {
227227
return Err(
228228
"columns may not have wildcard overrides in `query!()` or `query_as!()"
229229
.into(),

sqlx-macros/src/query/output.rs

Lines changed: 91 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use proc_macro2::{Ident, Span, TokenStream};
2-
use quote::{quote, ToTokens};
2+
use quote::{quote, ToTokens, TokenStreamExt};
33
use syn::Type;
44

55
use sqlx_core::column::Column;
@@ -14,7 +14,32 @@ use syn::Token;
1414

1515
pub struct RustColumn {
1616
pub(super) ident: Ident,
17-
pub(super) type_: Option<TokenStream>,
17+
pub(super) type_: ColumnType,
18+
}
19+
20+
pub(super) enum ColumnType {
21+
Exact(TokenStream),
22+
Wildcard,
23+
OptWildcard,
24+
}
25+
26+
impl ColumnType {
27+
pub(super) fn is_wildcard(&self) -> bool {
28+
match self {
29+
ColumnType::Exact(_) => false,
30+
_ => true,
31+
}
32+
}
33+
}
34+
35+
impl ToTokens for ColumnType {
36+
fn to_tokens(&self, tokens: &mut TokenStream) {
37+
tokens.append_all(match &self {
38+
ColumnType::Exact(type_) => type_.clone().into_iter(),
39+
ColumnType::Wildcard => quote! { _ }.into_iter(),
40+
ColumnType::OptWildcard => quote! { Option<_> }.into_iter(),
41+
})
42+
}
1843
}
1944

2045
struct DisplayColumn<'a> {
@@ -25,15 +50,25 @@ struct DisplayColumn<'a> {
2550

2651
struct ColumnDecl {
2752
ident: Ident,
28-
// TIL Rust still has OOP keywords like `abstract`, `final`, `override` and `virtual` reserved
29-
r#override: Option<ColumnOverride>,
53+
r#override: ColumnOverride,
3054
}
3155

32-
enum ColumnOverride {
56+
struct ColumnOverride {
57+
nullability: ColumnNullabilityOverride,
58+
type_: ColumnTypeOverride,
59+
}
60+
61+
#[derive(PartialEq)]
62+
enum ColumnNullabilityOverride {
3363
NonNull,
3464
Nullable,
35-
Wildcard,
65+
None,
66+
}
67+
68+
enum ColumnTypeOverride {
3669
Exact(Type),
70+
Wildcard,
71+
None,
3772
}
3873

3974
impl Display for DisplayColumn<'_> {
@@ -52,22 +87,30 @@ pub fn columns_to_rust<DB: DatabaseExt>(describe: &Describe<DB>) -> crate::Resul
5287
let decl = ColumnDecl::parse(&column.name())
5388
.map_err(|e| format!("column name {:?} is invalid: {}", column.name(), e))?;
5489

55-
let type_ = match decl.r#override {
56-
Some(ColumnOverride::Exact(ty)) => Some(ty.to_token_stream()),
57-
Some(ColumnOverride::Wildcard) => None,
58-
// these three could be combined but I prefer the clarity here
59-
Some(ColumnOverride::NonNull) => Some(get_column_type::<DB>(i, column)),
60-
Some(ColumnOverride::Nullable) => {
61-
let type_ = get_column_type::<DB>(i, column);
62-
Some(quote! { Option<#type_> })
90+
let ColumnOverride { nullability, type_ } = decl.r#override;
91+
92+
let nullable = match nullability {
93+
ColumnNullabilityOverride::NonNull => false,
94+
ColumnNullabilityOverride::Nullable => true,
95+
ColumnNullabilityOverride::None => describe.nullable(i).unwrap_or(true),
96+
};
97+
let type_ = match (type_, nullable) {
98+
(ColumnTypeOverride::Exact(type_), false) => {
99+
ColumnType::Exact(type_.to_token_stream())
63100
}
64-
None => {
65-
let type_ = get_column_type::<DB>(i, column);
101+
(ColumnTypeOverride::Exact(type_), true) => {
102+
ColumnType::Exact(quote! { Option<#type_> })
103+
}
104+
105+
(ColumnTypeOverride::Wildcard, false) => ColumnType::Wildcard,
106+
(ColumnTypeOverride::Wildcard, true) => ColumnType::OptWildcard,
66107

67-
if !describe.nullable(i).unwrap_or(true) {
68-
Some(type_)
108+
(ColumnTypeOverride::None, _) => {
109+
let type_ = get_column_type::<DB>(i, column);
110+
if !nullable {
111+
ColumnType::Exact(type_)
69112
} else {
70-
Some(quote! { Option<#type_> })
113+
ColumnType::Exact(quote! { Option<#type_> })
71114
}
72115
}
73116
};
@@ -97,14 +140,17 @@ pub fn quote_query_as<DB: DatabaseExt>(
97140
)| {
98141
match (input.checked, type_) {
99142
// we guarantee the type is valid so we can skip the runtime check
100-
(true, Some(type_)) => quote! {
143+
(true, ColumnType::Exact(type_)) => quote! {
101144
// binding to a `let` avoids confusing errors about
102145
// "try expression alternatives have incompatible types"
103146
// it doesn't seem to hurt inference in the other branches
104147
let #ident = row.try_get_unchecked::<#type_, _>(#i)?;
105148
},
106149
// type was overridden to be a wildcard so we fallback to the runtime check
107-
(true, None) => quote! ( let #ident = row.try_get(#i)?; ),
150+
(true, ColumnType::Wildcard) => quote! ( let #ident = row.try_get(#i)?; ),
151+
(true, ColumnType::OptWildcard) => {
152+
quote! ( let #ident = row.try_get::<Option<_>, _>(#i)?; )
153+
}
108154
// macro is the `_unchecked!()` variant so this will die in decoding if it's wrong
109155
(false, _) => quote!( let #ident = row.try_get_unchecked(#i)?; ),
110156
}
@@ -176,9 +222,12 @@ impl ColumnDecl {
176222
Ok(ColumnDecl {
177223
ident,
178224
r#override: if !remainder.is_empty() {
179-
Some(syn::parse_str(remainder)?)
225+
syn::parse_str(remainder)?
180226
} else {
181-
None
227+
ColumnOverride {
228+
nullability: ColumnNullabilityOverride::None,
229+
type_: ColumnTypeOverride::None,
230+
}
182231
},
183232
})
184233
}
@@ -188,27 +237,33 @@ impl Parse for ColumnOverride {
188237
fn parse(input: ParseStream) -> syn::Result<Self> {
189238
let lookahead = input.lookahead1();
190239

191-
if lookahead.peek(Token![:]) {
240+
let nullability = if lookahead.peek(Token![!]) {
241+
input.parse::<Token![!]>()?;
242+
243+
ColumnNullabilityOverride::NonNull
244+
} else if lookahead.peek(Token![?]) {
245+
input.parse::<Token![?]>()?;
246+
247+
ColumnNullabilityOverride::Nullable
248+
} else {
249+
ColumnNullabilityOverride::None
250+
};
251+
252+
let type_ = if input.lookahead1().peek(Token![:]) {
192253
input.parse::<Token![:]>()?;
193254

194255
let ty = Type::parse(input)?;
195256

196257
if let Type::Infer(_) = ty {
197-
Ok(ColumnOverride::Wildcard)
258+
ColumnTypeOverride::Wildcard
198259
} else {
199-
Ok(ColumnOverride::Exact(ty))
260+
ColumnTypeOverride::Exact(ty)
200261
}
201-
} else if lookahead.peek(Token![!]) {
202-
input.parse::<Token![!]>()?;
203-
204-
Ok(ColumnOverride::NonNull)
205-
} else if lookahead.peek(Token![?]) {
206-
input.parse::<Token![?]>()?;
207-
208-
Ok(ColumnOverride::Nullable)
209262
} else {
210-
Err(lookahead.error())
211-
}
263+
ColumnTypeOverride::None
264+
};
265+
266+
Ok(Self { nullability, type_ })
212267
}
213268
}
214269

src/macros.rs

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,8 @@
215215
/// Selecting a column `foo as "foo: T"` (Postgres / SQLite) or `` foo as `foo: T` `` (MySQL)
216216
/// overrides the inferred type which is useful when selecting user-defined custom types
217217
/// (dynamic type checking is still done so if the types are incompatible this will be an error
218-
/// at runtime instead of compile-time):
218+
/// at runtime instead of compile-time). Note that this syntax alone doesn't override inferred nullability,
219+
/// but it is compatible with the forced not-null and forced nullable annotations:
219220
///
220221
/// ```rust,ignore
221222
/// # async fn main() {
@@ -227,15 +228,27 @@
227228
/// let my_int = MyInt4(1);
228229
///
229230
/// // Postgres/SQLite
230-
/// sqlx::query!(r#"select 1 as "id: MyInt4""#) // MySQL: use "select 1 as `id: MyInt4`" instead
231+
/// sqlx::query!(r#"select 1 as "id!: MyInt4""#) // MySQL: use "select 1 as `id: MyInt4`" instead
231232
/// .fetch_one(&mut conn)
232233
/// .await?;
233234
///
234235
/// // For Postgres this would have been inferred to be `Option<i32>`, MySQL/SQLite `i32`
236+
/// // Note that while using `id: MyInt4` (without the `!`) would work the same for MySQL/SQLite,
237+
/// // Postgres would expect `Some(MyInt4(1))` and the code wouldn't compile
235238
/// assert_eq!(record.id, MyInt4(1));
236239
/// # }
237240
/// ```
238241
///
242+
/// ##### Overrides cheatsheet
243+
///
244+
/// | Syntax | Nullability | Type |
245+
/// | --------- | --------------- | ---------- |
246+
/// | `foo!` | Forced not-null | Inferred |
247+
/// | `foo?` | Forced nullable | Inferred |
248+
/// | `foo: T` | Inferred | Overridden |
249+
/// | `foo!: T` | Forced not-null | Overridden |
250+
/// | `foo?: T` | Forced nullable | Overridden |
251+
///
239252
/// ## Offline Mode (requires the `offline` feature)
240253
/// The macros can be configured to not require a live database connection for compilation,
241254
/// but it requires a couple extra steps:
@@ -601,18 +614,10 @@ macro_rules! query_file_as_unchecked (
601614
#[macro_export]
602615
macro_rules! migrate {
603616
($dir:literal) => {{
604-
#[macro_use]
605-
mod _macro_result {
606-
$crate::sqlx_macros::migrate!($dir);
607-
}
608-
macro_result!()
617+
$crate::sqlx_macros::migrate!($dir)
609618
}};
610619

611620
() => {{
612-
#[macro_use]
613-
mod _macro_result {
614-
$crate::sqlx_macros::migrate!("migrations");
615-
}
616-
macro_result!()
621+
$crate::sqlx_macros::migrate!("migrations")
617622
}};
618623
}

tests/any/pool.rs

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,4 @@
1-
use futures::{FutureExt, TryFutureExt};
21
use sqlx::any::AnyPoolOptions;
3-
use sqlx::prelude::*;
4-
use sqlx_core::any::AnyPool;
5-
use sqlx_test::new;
62
use std::sync::{
73
atomic::{AtomicUsize, Ordering},
84
Arc,
@@ -16,7 +12,7 @@ async fn pool_should_invoke_after_connect() -> anyhow::Result<()> {
1612
let pool = AnyPoolOptions::new()
1713
.after_connect({
1814
let counter = counter.clone();
19-
move |conn| {
15+
move |_conn| {
2016
let counter = counter.clone();
2117
Box::pin(async move {
2218
counter.fetch_add(1, Ordering::SeqCst);

0 commit comments

Comments
 (0)