|  | 
|  | 1 | +use std::collections::HashMap; | 
| 1 | 2 | use std::num::NonZeroUsize; | 
| 2 | 3 | use std::sync::{Arc, Mutex}; | 
| 3 | 4 | 
 | 
| 4 | 5 | use lru::LruCache; | 
| 5 | 6 | use pgt_query_ext::diagnostics::*; | 
| 6 | 7 | use pgt_text_size::TextRange; | 
|  | 8 | +use pgt_tokenizer::tokenize; | 
| 7 | 9 | 
 | 
| 8 | 10 | use super::statement_identifier::StatementId; | 
| 9 | 11 | 
 | 
| @@ -37,7 +39,7 @@ impl PgQueryStore { | 
| 37 | 39 |  } | 
| 38 | 40 | 
 | 
| 39 | 41 |  let r = Arc::new( | 
| 40 |  | - pgt_query::parse(statement.content()) | 
|  | 42 | + pgt_query::parse(&convert_to_positional_params(statement.content())) | 
| 41 | 43 |  .map_err(SyntaxDiagnostic::from) | 
| 42 | 44 |  .and_then(|ast| { | 
| 43 | 45 |  ast.into_root().ok_or_else(|| { | 
| @@ -87,10 +89,79 @@ impl PgQueryStore { | 
| 87 | 89 |  } | 
| 88 | 90 | } | 
| 89 | 91 | 
 | 
|  | 92 | +/// Converts named parameters in a SQL query string to positional parameters. | 
|  | 93 | +/// | 
|  | 94 | +/// This function scans the input SQL string for named parameters (e.g., `@param`, `:param`, `:'param'`) | 
|  | 95 | +/// and replaces them with positional parameters (e.g., `$1`, `$2`, etc.). | 
|  | 96 | +/// | 
|  | 97 | +/// It maintains the original spacing of the named parameters in the output string. | 
|  | 98 | +/// | 
|  | 99 | +/// Useful for preparing SQL queries for parsing or execution where named paramters are not supported. | 
|  | 100 | +pub fn convert_to_positional_params(text: &str) -> String { | 
|  | 101 | + let mut result = String::with_capacity(text.len()); | 
|  | 102 | + let mut param_mapping: HashMap<&str, usize> = HashMap::new(); | 
|  | 103 | + let mut param_index = 1; | 
|  | 104 | + let mut position = 0; | 
|  | 105 | + | 
|  | 106 | + for token in tokenize(text) { | 
|  | 107 | + let token_len = token.len as usize; | 
|  | 108 | + let token_text = &text[position..position + token_len]; | 
|  | 109 | + | 
|  | 110 | + if matches!(token.kind, pgt_tokenizer::TokenKind::NamedParam { .. }) { | 
|  | 111 | + let idx = match param_mapping.get(token_text) { | 
|  | 112 | + Some(&index) => index, | 
|  | 113 | + None => { | 
|  | 114 | + let index = param_index; | 
|  | 115 | + param_mapping.insert(token_text, index); | 
|  | 116 | + param_index += 1; | 
|  | 117 | + index | 
|  | 118 | + } | 
|  | 119 | + }; | 
|  | 120 | + | 
|  | 121 | + let replacement = format!("${}", idx); | 
|  | 122 | + let original_len = token_text.len(); | 
|  | 123 | + let replacement_len = replacement.len(); | 
|  | 124 | + | 
|  | 125 | + result.push_str(&replacement); | 
|  | 126 | + | 
|  | 127 | + // maintain original spacing | 
|  | 128 | + if replacement_len < original_len { | 
|  | 129 | + result.push_str(&" ".repeat(original_len - replacement_len)); | 
|  | 130 | + } | 
|  | 131 | + } else { | 
|  | 132 | + result.push_str(token_text); | 
|  | 133 | + } | 
|  | 134 | + | 
|  | 135 | + position += token_len; | 
|  | 136 | + } | 
|  | 137 | + | 
|  | 138 | + result | 
|  | 139 | +} | 
|  | 140 | + | 
| 90 | 141 | #[cfg(test)] | 
| 91 | 142 | mod tests { | 
| 92 | 143 |  use super::*; | 
| 93 | 144 | 
 | 
|  | 145 | + #[test] | 
|  | 146 | + fn test_convert_to_positional_params() { | 
|  | 147 | + let input = "select * from users where id = @one and name = :two and email = :'three';"; | 
|  | 148 | + let result = convert_to_positional_params(input); | 
|  | 149 | + assert_eq!( | 
|  | 150 | + result, | 
|  | 151 | + "select * from users where id = $1 and name = $2 and email = $3 ;" | 
|  | 152 | + ); | 
|  | 153 | + } | 
|  | 154 | + | 
|  | 155 | + #[test] | 
|  | 156 | + fn test_convert_to_positional_params_with_duplicates() { | 
|  | 157 | + let input = "select * from users where first_name = @one and starts_with(email, @one) and created_at > @two;"; | 
|  | 158 | + let result = convert_to_positional_params(input); | 
|  | 159 | + assert_eq!( | 
|  | 160 | + result, | 
|  | 161 | + "select * from users where first_name = $1 and starts_with(email, $1 ) and created_at > $2 ;" | 
|  | 162 | + ); | 
|  | 163 | + } | 
|  | 164 | + | 
| 94 | 165 |  #[test] | 
| 95 | 166 |  fn test_plpgsql_syntax_error() { | 
| 96 | 167 |  let input = " | 
|  | 
0 commit comments