Skip to content

Commit 22b4e54

Browse files
kewl
1 parent 36345c6 commit 22b4e54

File tree

5 files changed

+77
-18
lines changed

5 files changed

+77
-18
lines changed

crates/pgt_typecheck/src/diagnostics.rs

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@ use std::io;
22

33
use pgt_console::markup;
44
use pgt_diagnostics::{Advices, Diagnostic, LogCategory, MessageAndDescription, Severity, Visit};
5-
use pgt_text_size::{TextRange, TextRangeReplacement, TextSize};
5+
use pgt_text_size::{TextRange, TextSize};
66
use sqlx::postgres::{PgDatabaseError, PgSeverity};
77

8+
use crate::{IdentifierType, TypedReplacement};
9+
810
/// A specialized diagnostic for the typechecker.
911
///
1012
/// Type diagnostics are always **errors**.
@@ -94,18 +96,50 @@ impl Advices for TypecheckAdvices {
9496
}
9597
}
9698

99+
/// Finds the original type at the given position in the adjusted text
100+
fn find_type_at_position(
101+
adjusted_position: TextSize,
102+
type_info: &[(TextRange, IdentifierType)],
103+
) -> Option<&IdentifierType> {
104+
type_info
105+
.iter()
106+
.find(|(range, _)| range.contains(adjusted_position))
107+
.map(|(_, type_)| type_)
108+
}
109+
110+
/// Rewrites error messages to show the original type name instead of the replaced literal value
111+
fn rewrite_error_message(original_message: &str, identifier_type: &IdentifierType) -> String {
112+
// pattern: invalid input syntax for type X: "literal_value"
113+
// we want to replace "literal_value" with the type name
114+
115+
if let Some(colon_pos) = original_message.rfind(": ") {
116+
let before_value = &original_message[..colon_pos];
117+
118+
// build the type name, including schema if present
119+
let type_name = if let Some(schema) = &identifier_type.schema {
120+
format!("{}.{}", schema, identifier_type.name)
121+
} else {
122+
identifier_type.name.clone()
123+
};
124+
125+
format!("{}: {}", before_value, type_name)
126+
} else {
127+
original_message.to_string()
128+
}
129+
}
130+
97131
pub(crate) fn create_type_error(
98132
pg_err: &PgDatabaseError,
99133
ts: &tree_sitter::Tree,
100-
txt_replacement: TextRangeReplacement,
134+
typed_replacement: TypedReplacement,
101135
) -> TypecheckDiagnostic {
102136
let position = pg_err.position().and_then(|pos| match pos {
103137
sqlx::postgres::PgErrorPosition::Original(pos) => Some(pos - 1),
104138
_ => None,
105139
});
106140

107141
let range = position.and_then(|pos| {
108-
let adjusted = txt_replacement.to_original_position(TextSize::new(pos.try_into().unwrap()));
142+
let adjusted = typed_replacement.replacement.to_original_position(TextSize::new(pos.try_into().unwrap()));
109143

110144
ts.root_node()
111145
.named_descendant_for_byte_range(adjusted.into(), adjusted.into())
@@ -128,8 +162,20 @@ pub(crate) fn create_type_error(
128162
PgSeverity::Log => Severity::Information,
129163
};
130164

165+
// check if the error position corresponds to a replaced parameter
166+
let message = if let Some(pos) = position {
167+
let adjusted_pos = TextSize::new(pos.try_into().unwrap());
168+
if let Some(original_type) = find_type_at_position(adjusted_pos, &typed_replacement.type_info) {
169+
rewrite_error_message(pg_err.message(), original_type)
170+
} else {
171+
pg_err.to_string()
172+
}
173+
} else {
174+
pg_err.to_string()
175+
};
176+
131177
TypecheckDiagnostic {
132-
message: pg_err.to_string().into(),
178+
message: message.into(),
133179
severity,
134180
span: range,
135181
advices: TypecheckAdvices {

crates/pgt_typecheck/src/lib.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use sqlx::postgres::PgDatabaseError;
1010
pub use sqlx::postgres::PgSeverity;
1111
use sqlx::{Executor, PgPool};
1212
use typed_identifier::apply_identifiers;
13-
pub use typed_identifier::{IdentifierType, TypedIdentifier};
13+
pub use typed_identifier::{IdentifierType, TypedIdentifier, TypedReplacement};
1414

1515
#[derive(Debug)]
1616
pub struct TypecheckParams<'a> {
@@ -48,7 +48,7 @@ pub async fn check_sql(
4848
// each typecheck operation.
4949
conn.close_on_drop();
5050

51-
let replacement = apply_identifiers(
51+
let typed_replacement = apply_identifiers(
5252
params.identifiers,
5353
params.schema_cache,
5454
params.tree,
@@ -68,13 +68,13 @@ pub async fn check_sql(
6868
conn.execute(&*search_path_query).await?;
6969
}
7070

71-
let res = conn.prepare(replacement.text()).await;
71+
let res = conn.prepare(typed_replacement.replacement.text()).await;
7272

7373
match res {
7474
Ok(_) => Ok(None),
7575
Err(sqlx::Error::Database(err)) => {
7676
let pg_err = err.downcast_ref::<PgDatabaseError>();
77-
Ok(Some(create_type_error(pg_err, params.tree, replacement)))
77+
Ok(Some(create_type_error(pg_err, params.tree, typed_replacement)))
7878
}
7979
Err(err) => Err(err),
8080
}

crates/pgt_typecheck/src/typed_identifier.rs

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use pgt_schema_cache::PostgresType;
2-
use pgt_text_size::{TextRangeReplacement, TextRangeReplacementBuilder};
2+
use pgt_text_size::{TextRange, TextRangeReplacement, TextRangeReplacementBuilder};
33
use pgt_treesitter::queries::{ParameterMatch, TreeSitterQueriesExecutor};
44

55
/// It is used to replace parameters within the SQL string.
@@ -21,13 +21,20 @@ pub struct IdentifierType {
2121
pub is_array: bool,
2222
}
2323

24+
/// Contains the text replacement along with metadata about which ranges correspond to which types.
25+
#[derive(Debug)]
26+
pub struct TypedReplacement {
27+
pub replacement: TextRangeReplacement,
28+
pub type_info: Vec<(TextRange, IdentifierType)>,
29+
}
30+
2431
/// Applies the identifiers to the SQL string by replacing them with their default values.
2532
pub fn apply_identifiers<'a>(
2633
identifiers: Vec<TypedIdentifier>,
2734
schema_cache: &'a pgt_schema_cache::SchemaCache,
2835
cst: &'a tree_sitter::Tree,
2936
sql: &'a str,
30-
) -> TextRangeReplacement {
37+
) -> TypedReplacement {
3138
let mut executor = TreeSitterQueriesExecutor::new(cst.root_node(), sql);
3239

3340
executor.add_query_results::<ParameterMatch>();
@@ -46,20 +53,26 @@ pub fn apply_identifiers<'a>(
4653
// Resolve the type based on whether we're accessing a field of a composite type
4754
let type_ = resolve_type(identifier, position, &parts, schema_cache)?;
4855

49-
Some((m.get_byte_range(), type_, identifier.type_.is_array))
56+
Some((m.get_byte_range(), type_, identifier.type_.clone()))
5057
})
5158
.collect();
5259

5360
let mut text_range_replacement_builder = TextRangeReplacementBuilder::new(sql);
61+
let mut type_info = vec![];
5462

55-
for (range, type_, is_array) in replacements {
56-
let default_value = get_formatted_default_value(type_, is_array);
63+
for (range, type_, original_type) in replacements {
64+
let default_value = get_formatted_default_value(type_, original_type.is_array);
5765

5866
text_range_replacement_builder
5967
.replace_range(range.clone().try_into().unwrap(), &default_value);
68+
69+
type_info.push((range.try_into().unwrap(), original_type));
6070
}
6171

62-
text_range_replacement_builder.build()
72+
TypedReplacement {
73+
replacement: text_range_replacement_builder.build(),
74+
type_info,
75+
}
6376
}
6477

6578
/// Format the default value based on the type and whether it's an array
@@ -307,7 +320,7 @@ mod tests {
307320
let replacement = super::apply_identifiers(identifiers, &schema_cache, &tree, input);
308321

309322
assert_eq!(
310-
replacement.text(),
323+
replacement.replacement.text(),
311324
// the numeric parameters are filled with 0;
312325
"select 0 + 0 + 0 + 0 + 0 + 'critical'"
313326
);
@@ -357,7 +370,7 @@ mod tests {
357370
let replacement = super::apply_identifiers(identifiers, &schema_cache, &tree, input);
358371

359372
assert_eq!(
360-
replacement.text(),
373+
replacement.replacement.text(),
361374
r#"select id from auth.users where email_change_confirm_status = '00000000-0000-0000-0000-000000000000' and email = '';"#
362375
);
363376
}

crates/pgt_typecheck/tests/snapshots/invalid_type_in_function_longer_default.snap

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@ expression: content
44
---
55
delete from public.contacts where id = ~~~uid~~~;
66

7-
invalid input syntax for type integer: &quot;00000000-0000-0000-0000-000000000000&quot;
7+
invalid input syntax for type integer: uuid

crates/pgt_typecheck/tests/snapshots/invalid_type_in_function_shorter_default.snap

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@ expression: content
44
---
55
delete from public.contacts where id = ~~~contact_name~~~;
66

7-
invalid input syntax for type integer: &quot;&quot;
7+
invalid input syntax for type integer: text

0 commit comments

Comments
 (0)