Skip to content

Commit b2ec636

Browse files
very good, very good
1 parent 22b4e54 commit b2ec636

File tree

7 files changed

+146
-65
lines changed

7 files changed

+146
-65
lines changed

Cargo.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/pgt_typecheck/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,15 @@ version = "0.0.0"
1414
[dependencies]
1515
globset = "0.4.16"
1616
itertools = { version = "0.14.0" }
17+
once_cell = "1.20.2"
1718
pgt_console.workspace = true
1819
pgt_diagnostics.workspace = true
1920
pgt_query.workspace = true
2021
pgt_schema_cache.workspace = true
2122
pgt_text_size.workspace = true
2223
pgt_treesitter.workspace = true
2324
pgt_treesitter_grammar.workspace = true
25+
regex = "1.11.1"
2426
sqlx.workspace = true
2527
tokio.workspace = true
2628
tree-sitter.workspace = true

crates/pgt_typecheck/src/diagnostics.rs

Lines changed: 70 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
use std::io;
22

3+
use once_cell::sync::Lazy;
34
use pgt_console::markup;
45
use pgt_diagnostics::{Advices, Diagnostic, LogCategory, MessageAndDescription, Severity, Visit};
56
use pgt_text_size::{TextRange, TextSize};
7+
use regex::Regex;
68
use sqlx::postgres::{PgDatabaseError, PgSeverity};
79

8-
use crate::{IdentifierType, TypedReplacement};
10+
use crate::typed_identifier::{IdentifierReplacement, TypedReplacement};
911

1012
/// A specialized diagnostic for the typechecker.
1113
///
@@ -96,36 +98,65 @@ impl Advices for TypecheckAdvices {
9698
}
9799
}
98100

99-
/// Finds the original type at the given position in the adjusted text
100-
fn find_type_at_position(
101-
adjusted_position: TextSize,
102-
type_info: &[(TextRange, IdentifierType)],
103-
) -> Option<&IdentifierType> {
104-
type_info
105-
.iter()
106-
.find(|(range, _)| range.contains(adjusted_position))
107-
.map(|(_, type_)| type_)
101+
/// Pattern and rewrite rule for error messages
102+
struct ErrorRewriteRule {
103+
pattern: Regex,
104+
rewrite: fn(&regex::Captures, &IdentifierReplacement) -> String,
108105
}
109106

110-
/// Rewrites error messages to show the original type name instead of the replaced literal value
111-
fn rewrite_error_message(original_message: &str, identifier_type: &IdentifierType) -> String {
112-
// pattern: invalid input syntax for type X: "literal_value"
113-
// we want to replace "literal_value" with the type name
114-
115-
if let Some(colon_pos) = original_message.rfind(": ") {
116-
let before_value = &original_message[..colon_pos];
117-
118-
// build the type name, including schema if present
119-
let type_name = if let Some(schema) = &identifier_type.schema {
120-
format!("{}.{}", schema, identifier_type.name)
121-
} else {
122-
identifier_type.name.clone()
123-
};
107+
static ERROR_RULES: Lazy<Vec<ErrorRewriteRule>> = Lazy::new(|| {
108+
vec![
109+
ErrorRewriteRule {
110+
pattern: Regex::new(r#"invalid input syntax for type (\w+): "([^"]*)""#).unwrap(),
111+
rewrite: |caps, replacement| {
112+
let expected_type = &caps[1];
113+
format!(
114+
"`{}` is of type {}, not {}",
115+
replacement.original_name, replacement.type_name, expected_type
116+
)
117+
},
118+
},
119+
ErrorRewriteRule {
120+
pattern: Regex::new(
121+
r#"column "([^"]*)" is of type (\w+) but expression is of type (\w+)"#,
122+
)
123+
.unwrap(),
124+
rewrite: |caps, replacement| {
125+
let column = &caps[1];
126+
let expected_type = &caps[2];
127+
format!(
128+
"column `{}` expects {}, but `{}` is of type {}",
129+
column, expected_type, replacement.original_name, replacement.type_name
130+
)
131+
},
132+
},
133+
ErrorRewriteRule {
134+
pattern: Regex::new(r#"operator does not exist: (.+)"#).unwrap(),
135+
rewrite: |caps, replacement| {
136+
let operator_expr = &caps[1];
137+
format!(
138+
"operator does not exist: {} (parameter `{}` is of type {})",
139+
operator_expr, replacement.original_name, replacement.type_name
140+
)
141+
},
142+
},
143+
]
144+
});
124145

125-
format!("{}: {}", before_value, type_name)
126-
} else {
127-
original_message.to_string()
146+
/// Rewrites Postgres error messages to be more user-friendly
147+
fn rewrite_error_message(pg_error_message: &str, replacement: &IdentifierReplacement) -> String {
148+
// try each rule
149+
for rule in ERROR_RULES.iter() {
150+
if let Some(caps) = rule.pattern.captures(pg_error_message) {
151+
return (rule.rewrite)(&caps, replacement);
152+
}
128153
}
154+
155+
// fallback: generic value replacement
156+
let unquoted_default = replacement.default_value.trim_matches('\'');
157+
pg_error_message
158+
.replace(&format!("\"{}\"", unquoted_default), &replacement.type_name)
159+
.replace(&format!("'{}'", unquoted_default), &replacement.type_name)
129160
}
130161

131162
pub(crate) fn create_type_error(
@@ -138,11 +169,17 @@ pub(crate) fn create_type_error(
138169
_ => None,
139170
});
140171

141-
let range = position.and_then(|pos| {
142-
let adjusted = typed_replacement.replacement.to_original_position(TextSize::new(pos.try_into().unwrap()));
172+
let original_position = position.map(|p| {
173+
let pos = TextSize::new(p.try_into().unwrap());
174+
175+
typed_replacement
176+
.text_replacement()
177+
.to_original_position(pos)
178+
});
143179

180+
let range = original_position.and_then(|pos| {
144181
ts.root_node()
145-
.named_descendant_for_byte_range(adjusted.into(), adjusted.into())
182+
.named_descendant_for_byte_range(pos.into(), pos.into())
146183
.map(|node| {
147184
TextRange::new(
148185
node.start_byte().try_into().unwrap(),
@@ -162,11 +199,9 @@ pub(crate) fn create_type_error(
162199
PgSeverity::Log => Severity::Information,
163200
};
164201

165-
// check if the error position corresponds to a replaced parameter
166-
let message = if let Some(pos) = position {
167-
let adjusted_pos = TextSize::new(pos.try_into().unwrap());
168-
if let Some(original_type) = find_type_at_position(adjusted_pos, &typed_replacement.type_info) {
169-
rewrite_error_message(pg_err.message(), original_type)
202+
let message = if let Some(pos) = original_position {
203+
if let Some(replacement) = typed_replacement.find_type_at_position(pos) {
204+
rewrite_error_message(pg_err.message(), replacement)
170205
} else {
171206
pg_err.to_string()
172207
}

crates/pgt_typecheck/src/lib.rs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use sqlx::postgres::PgDatabaseError;
1010
pub use sqlx::postgres::PgSeverity;
1111
use sqlx::{Executor, PgPool};
1212
use typed_identifier::apply_identifiers;
13-
pub use typed_identifier::{IdentifierType, TypedIdentifier, TypedReplacement};
13+
pub use typed_identifier::{IdentifierType, TypedIdentifier};
1414

1515
#[derive(Debug)]
1616
pub struct TypecheckParams<'a> {
@@ -68,13 +68,19 @@ pub async fn check_sql(
6868
conn.execute(&*search_path_query).await?;
6969
}
7070

71-
let res = conn.prepare(typed_replacement.replacement.text()).await;
71+
let res = conn
72+
.prepare(typed_replacement.text_replacement().text())
73+
.await;
7274

7375
match res {
7476
Ok(_) => Ok(None),
7577
Err(sqlx::Error::Database(err)) => {
7678
let pg_err = err.downcast_ref::<PgDatabaseError>();
77-
Ok(Some(create_type_error(pg_err, params.tree, typed_replacement)))
79+
Ok(Some(create_type_error(
80+
pg_err,
81+
params.tree,
82+
typed_replacement,
83+
)))
7884
}
7985
Err(err) => Err(err),
8086
}

crates/pgt_typecheck/src/typed_identifier.rs

Lines changed: 61 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use pgt_schema_cache::PostgresType;
2-
use pgt_text_size::{TextRange, TextRangeReplacement, TextRangeReplacementBuilder};
2+
use pgt_text_size::{TextRange, TextRangeReplacement, TextRangeReplacementBuilder, TextSize};
33
use pgt_treesitter::queries::{ParameterMatch, TreeSitterQueriesExecutor};
44

55
/// It is used to replace parameters within the SQL string.
@@ -21,11 +21,52 @@ pub struct IdentifierType {
2121
pub is_array: bool,
2222
}
2323

24+
#[derive(Debug)]
25+
pub(crate) struct IdentifierReplacement {
26+
pub original_name: String,
27+
pub original_range: std::ops::Range<usize>,
28+
/// The default value with which the identifier was replaced, e.g. `''` for a TEXT param.
29+
pub default_value: String,
30+
pub type_name: String,
31+
}
32+
2433
/// Contains the text replacement along with metadata about which ranges correspond to which types.
2534
#[derive(Debug)]
26-
pub struct TypedReplacement {
27-
pub replacement: TextRangeReplacement,
28-
pub type_info: Vec<(TextRange, IdentifierType)>,
35+
pub(crate) struct TypedReplacement {
36+
text_replacement: TextRangeReplacement,
37+
identifier_replacements: Vec<IdentifierReplacement>,
38+
}
39+
40+
impl TypedReplacement {
41+
pub fn new(sql: &str, replacements: Vec<IdentifierReplacement>) -> Self {
42+
let mut text_range_replacement_builder = TextRangeReplacementBuilder::new(sql);
43+
44+
for replacement in &replacements {
45+
let text_range: TextRange = replacement.original_range.clone().try_into().unwrap();
46+
text_range_replacement_builder.replace_range(text_range, &replacement.default_value);
47+
}
48+
49+
Self {
50+
identifier_replacements: replacements,
51+
text_replacement: text_range_replacement_builder.build(),
52+
}
53+
}
54+
55+
/// Finds the original type at the given position in the adjusted text
56+
pub(crate) fn find_type_at_position(
57+
&self,
58+
original_position: TextSize,
59+
) -> Option<&IdentifierReplacement> {
60+
self.identifier_replacements.iter().find(|replacement| {
61+
replacement
62+
.original_range
63+
.contains(&original_position.try_into().unwrap())
64+
})
65+
}
66+
67+
pub(crate) fn text_replacement(&self) -> &TextRangeReplacement {
68+
&self.text_replacement
69+
}
2970
}
3071

3172
/// Applies the identifiers to the SQL string by replacing them with their default values.
@@ -40,7 +81,7 @@ pub fn apply_identifiers<'a>(
4081
executor.add_query_results::<ParameterMatch>();
4182

4283
// Collect all replacements first to avoid modifying the string while iterating
43-
let replacements: Vec<_> = executor
84+
let replacements: Vec<IdentifierReplacement> = executor
4485
.get_iter(None)
4586
.filter_map(|q| {
4687
let m: &ParameterMatch = q.try_into().ok()?;
@@ -51,28 +92,23 @@ pub fn apply_identifiers<'a>(
5192
let (identifier, position) = find_matching_identifier(&parts, &identifiers)?;
5293

5394
// Resolve the type based on whether we're accessing a field of a composite type
54-
let type_ = resolve_type(identifier, position, &parts, schema_cache)?;
55-
56-
Some((m.get_byte_range(), type_, identifier.type_.clone()))
57-
})
58-
.collect();
95+
let postgres_type = resolve_type(identifier, position, &parts, schema_cache)?;
5996

60-
let mut text_range_replacement_builder = TextRangeReplacementBuilder::new(sql);
61-
let mut type_info = vec![];
97+
let default_value =
98+
get_formatted_default_value(postgres_type, identifier.type_.is_array);
6299

63-
for (range, type_, original_type) in replacements {
64-
let default_value = get_formatted_default_value(type_, original_type.is_array);
100+
let replacement = IdentifierReplacement {
101+
default_value,
102+
original_name: identifier.name.clone().unwrap_or("".into()),
103+
original_range: m.get_byte_range(),
104+
type_name: identifier.type_.name.clone(),
105+
};
65106

66-
text_range_replacement_builder
67-
.replace_range(range.clone().try_into().unwrap(), &default_value);
68-
69-
type_info.push((range.try_into().unwrap(), original_type));
70-
}
107+
Some(replacement)
108+
})
109+
.collect();
71110

72-
TypedReplacement {
73-
replacement: text_range_replacement_builder.build(),
74-
type_info,
75-
}
111+
TypedReplacement::new(sql, replacements)
76112
}
77113

78114
/// Format the default value based on the type and whether it's an array
@@ -320,7 +356,7 @@ mod tests {
320356
let replacement = super::apply_identifiers(identifiers, &schema_cache, &tree, input);
321357

322358
assert_eq!(
323-
replacement.replacement.text(),
359+
replacement.text_replacement.text(),
324360
// the numeric parameters are filled with 0;
325361
"select 0 + 0 + 0 + 0 + 0 + 'critical'"
326362
);
@@ -370,7 +406,7 @@ mod tests {
370406
let replacement = super::apply_identifiers(identifiers, &schema_cache, &tree, input);
371407

372408
assert_eq!(
373-
replacement.replacement.text(),
409+
replacement.text_replacement.text(),
374410
r#"select id from auth.users where email_change_confirm_status = '00000000-0000-0000-0000-000000000000' and email = '';"#
375411
);
376412
}

crates/pgt_typecheck/tests/snapshots/invalid_type_in_function_longer_default.snap

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@ expression: content
44
---
55
delete from public.contacts where id = ~~~uid~~~;
66

7-
invalid input syntax for type integer: uuid
7+
`uid` is of type uuid, not integer

crates/pgt_typecheck/tests/snapshots/invalid_type_in_function_shorter_default.snap

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@ expression: content
44
---
55
delete from public.contacts where id = ~~~contact_name~~~;
66

7-
invalid input syntax for type integer: text
7+
`contact_name` is of type text, not integer

0 commit comments

Comments
 (0)