Skip to content
Merged
Next Next commit
refactor: parser
  • Loading branch information
psteinroe committed Apr 9, 2025
commit 34a9687245d77f8bd024090d904cbe01b79ade9b
45 changes: 45 additions & 0 deletions crates/pgt_query_ext/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,48 @@ pub fn parse(sql: &str) -> Result<NodeEnum> {
.ok_or_else(|| Error::Parse("Unable to find root node".to_string()))
})?
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_sql_1() {
let input = "CREATE FUNCTION add(integer, integer) RETURNS integer
AS 'select $1 + $2;'
LANGUAGE SQL
IMMUTABLE
RETURNS NULL ON NULL INPUT;";
println!("{:#?}", parse(input).unwrap());
// print after 42
println!("{:#?}", &input[42..]);
}

#[test]
fn test_sql_2() {
let input = "CREATE FUNCTION add() RETURNS integer
AS $sql$select 1 + 2;$sql$
LANGUAGE SQL
IMMUTABLE
RETURNS NULL ON NULL INPUT;";
println!("{:#?}", parse(input).unwrap());
// print after 58
println!("{:#?}", &input[58..]);
}

#[test]
fn test_plpsql() {
let input = "CREATE FUNCTION add(integer, integer) RETURNS integer
AS $s$
begin
return $1 + $2;
end
$s$
LANGUAGE plpgsql
IMMUTABLE
RETURNS NULL ON NULL INPUT;";
println!("{:#?}", parse(input).unwrap());
// print after 58
println!("{:#?}", &input[58..]);
}
}
104 changes: 97 additions & 7 deletions crates/pgt_workspace/src/workspace/server.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
use std::{fs, panic::RefUnwindSafe, path::Path, sync::RwLock};
use std::{
fs,
panic::RefUnwindSafe,
path::Path,
sync::{Arc, Mutex, RwLock},
};

use analyser::AnalyserVisitorBuilder;
use async_helper::run_async;
Expand All @@ -15,9 +20,11 @@ use pgt_diagnostics::{Diagnostic, DiagnosticExt, Severity, serde::Diagnostic as
use pgt_fs::{ConfigName, PgTPath};
use pgt_typecheck::TypecheckParams;
use schema_cache_manager::SchemaCacheManager;
use sql_function::SQLFunctionBodyStore;
use sqlx::Executor;
use tracing::info;
use tree_sitter::TreeSitterStore;
use tree_sitter_parser::TreeSitterParserStore;

use crate::{
WorkspaceError,
Expand All @@ -44,8 +51,11 @@ mod change;
mod db_connection;
mod document;
mod migration;
mod parser;
mod pg_query;
mod schema_cache_manager;
mod sql_function;
mod statement_identifier;
mod tree_sitter;

pub(super) struct WorkspaceServer {
Expand All @@ -60,6 +70,8 @@ pub(super) struct WorkspaceServer {

tree_sitter: TreeSitterStore,
pg_query: PgQueryStore,
sql_functions: SQLFunctionBodyStore,
ts_parser: TreeSitterParserStore,

connection: RwLock<DbConnection>,
}
Expand All @@ -86,6 +98,8 @@ impl WorkspaceServer {
pg_query: PgQueryStore::new(),
schema_cache: SchemaCacheManager::default(),
connection: RwLock::default(),
sql_functions: SQLFunctionBodyStore::new(),
ts_parser: TreeSitterParserStore::new(),
}
}

Expand Down Expand Up @@ -183,10 +197,17 @@ impl Workspace for WorkspaceServer {
fn open_file(&self, params: OpenFileParams) -> Result<(), WorkspaceError> {
let doc = Document::new(params.path.clone(), params.content, params.version);

let doc_parser = self.ts_parser.get_parser(params.path.clone());
let mut parser = doc_parser.lock().expect("Error locking parser");
doc.iter_statements_with_text().for_each(|(stmt, content)| {
self.tree_sitter.add_statement(&stmt, content);
self.tree_sitter.add_statement(&mut parser, &stmt, content);
self.pg_query.add_statement(&stmt, content);
if let Some(ast) = self.pg_query.get_ast(&stmt) {
self.sql_functions
.add_statement(&mut parser, &ast, &stmt, content);
}
});
drop(parser);

self.documents.insert(params.path, doc);

Expand All @@ -203,8 +224,11 @@ impl Workspace for WorkspaceServer {
for stmt in doc.iter_statements() {
self.tree_sitter.remove_statement(&stmt);
self.pg_query.remove_statement(&stmt);
self.sql_functions.remove_statement(&stmt);
}

self.ts_parser.remove_parser(&params.path);

Ok(())
}

Expand All @@ -223,6 +247,8 @@ impl Workspace for WorkspaceServer {
params.version,
));

let doc_parser = self.ts_parser.get_parser(params.path.clone());
let mut parser = doc_parser.lock().expect("Error locking parser");
for c in &doc.apply_file_change(&params) {
match c {
StatementChange::Added(added) => {
Expand All @@ -232,8 +258,17 @@ impl Workspace for WorkspaceServer {
added.stmt.path.as_os_str().to_str(),
added.text
);
self.tree_sitter.add_statement(&added.stmt, &added.text);
self.tree_sitter
.add_statement(&mut parser, &added.stmt, &added.text);
self.pg_query.add_statement(&added.stmt, &added.text);
if let Some(ast) = self.pg_query.get_ast(&added.stmt) {
self.sql_functions.add_statement(
&mut parser,
&ast,
&added.stmt,
&added.text,
);
}
}
StatementChange::Deleted(s) => {
tracing::debug!(
Expand All @@ -243,6 +278,7 @@ impl Workspace for WorkspaceServer {
);
self.tree_sitter.remove_statement(s);
self.pg_query.remove_statement(s);
self.sql_functions.remove_statement(s);
}
StatementChange::Modified(s) => {
tracing::debug!(
Expand All @@ -256,11 +292,15 @@ impl Workspace for WorkspaceServer {
s.change_text
);

self.tree_sitter.modify_statement(s);
self.tree_sitter.modify_statement(&mut parser, s);
self.pg_query.modify_statement(s);
if let Some(ast) = self.pg_query.get_ast(&s.new_stmt) {
self.sql_functions.modify_statement(&mut parser, &ast, s);
}
}
}
}
drop(parser);

Ok(())
}
Expand Down Expand Up @@ -420,10 +460,25 @@ impl Workspace for WorkspaceServer {
{
let typecheck_params: Vec<_> = doc
.iter_statements_with_text_and_range()
.map(|(stmt, range, text)| {
.flat_map(|(stmt, range, text)| {
let ast = self.pg_query.get_ast(&stmt);
let tree = self.tree_sitter.get_parse_tree(&stmt);
(text.to_string(), ast, tree, *range)

let mut res = vec![(text.to_string(), ast, tree, *range)];

if let Some(fn_body) = self.sql_functions.get_function_body(&stmt) {
// fn_body range is within the statement -> adjust it to be relative to the
// document instead (as the other ranges)
let fn_range = fn_body.range + range.start();
res.push((
fn_body.body.clone(),
fn_body.ast.clone().map(Arc::new),
Some(Arc::new(fn_body.cst.clone())),
fn_range,
));
}

res
})
.collect();

Expand Down Expand Up @@ -479,6 +534,27 @@ impl Workspace for WorkspaceServer {
);
}

if let Some(fn_body) = self.sql_functions.get_function_body(&stmt) {
if let Some(fn_diag) = &fn_body.syntax_diagnostics {
stmt_diagnostics.push(SDiagnostic::new(
fn_diag
.clone()
.with_file_path(params.path.as_path().display().to_string())
.with_file_span(fn_body.range + r.start()),
));
}

if let Some(ast) = &fn_body.ast {
stmt_diagnostics.extend(
analyser
.run(AnalyserContext { root: ast })
.into_iter()
.map(SDiagnostic::new)
.collect::<Vec<_>>(),
);
}
}

stmt_diagnostics
.into_iter()
.map(|d| {
Expand Down Expand Up @@ -559,13 +635,27 @@ impl Workspace for WorkspaceServer {

let schema_cache = self.schema_cache.load(pool)?;

let items = pgt_completions::complete(pgt_completions::CompletionParams {
let mut items = pgt_completions::complete(pgt_completions::CompletionParams {
position,
schema: schema_cache.as_ref(),
tree: tree.as_deref(),
text: text.to_string(),
});

if let Some(f) = self.sql_functions.get_function_body(&statement) {
let fn_text = f.body.clone();
let fn_range = f.range + stmt_range.start();

items.extend(pgt_completions::complete(
pgt_completions::CompletionParams {
position: position - fn_range.start(),
schema: schema_cache.as_ref(),
tree: Some(&f.cst),
text: fn_text,
},
));
}

Ok(CompletionsResult { items })
}
}
Expand Down
Loading