Skip to content
10 changes: 10 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ pg_schema_cache = { path = "./crates/pg_schema_cache", version = "0.0.
pg_statement_splitter = { path = "./crates/pg_statement_splitter", version = "0.0.0" }
pg_syntax = { path = "./crates/pg_syntax", version = "0.0.0" }
pg_text_edit = { path = "./crates/pg_text_edit", version = "0.0.0" }
pg_treesitter_queries = { path = "./crates/pg_treesitter_queries", version = "0.0.0" }
pg_type_resolver = { path = "./crates/pg_type_resolver", version = "0.0.0" }
pg_typecheck = { path = "./crates/pg_typecheck", version = "0.0.0" }
pg_workspace = { path = "./crates/pg_workspace", version = "0.0.0" }
Expand Down
1 change: 1 addition & 0 deletions crates/pg_completions/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ serde_json = { workspace = true }
pg_schema_cache.workspace = true
tree-sitter.workspace = true
tree_sitter_sql.workspace = true
pg_treesitter_queries.workspace = true

sqlx.workspace = true

Expand Down
3 changes: 2 additions & 1 deletion crates/pg_completions/src/complete.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::{
builder::CompletionBuilder,
context::CompletionContext,
item::CompletionItem,
providers::{complete_functions, complete_tables},
providers::{complete_columns, complete_functions, complete_tables},
};

pub const LIMIT: usize = 50;
Expand Down Expand Up @@ -38,6 +38,7 @@ pub fn complete(params: CompletionParams) -> CompletionResult {

complete_tables(&ctx, &mut builder);
complete_functions(&ctx, &mut builder);
complete_columns(&ctx, &mut builder);

builder.finish()
}
89 changes: 71 additions & 18 deletions crates/pg_completions/src/context.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
use std::collections::{HashMap, HashSet};

use pg_schema_cache::SchemaCache;
use pg_treesitter_queries::{
queries::{self, QueryResult},
TreeSitterQueriesExecutor,
};

use crate::CompletionParams;

Expand Down Expand Up @@ -52,27 +58,72 @@ pub(crate) struct CompletionContext<'a> {
pub schema_name: Option<String>,
pub wrapping_clause_type: Option<ClauseType>,
pub is_invocation: bool,
pub wrapping_statement_range: Option<tree_sitter::Range>,

pub mentioned_relations: HashMap<Option<String>, HashSet<String>>,
}

impl<'a> CompletionContext<'a> {
pub fn new(params: &'a CompletionParams) -> Self {
pub fn new(params: &'a CompletionParams<'a>) -> Self {
let mut ctx = Self {
tree: params.tree,
text: &params.text,
schema_cache: params.schema,
position: usize::from(params.position),

ts_node: None,
schema_name: None,
wrapping_clause_type: None,
wrapping_statement_range: None,
is_invocation: false,
mentioned_relations: HashMap::new(),
};

ctx.gather_tree_context();
ctx.gather_info_from_ts_queries();

dbg!(ctx.wrapping_statement_range);

ctx
}

fn gather_info_from_ts_queries(&mut self) {
let tree = match self.tree.as_ref() {
None => return,
Some(t) => t,
};

let stmt_range = self.wrapping_statement_range.as_ref();
let sql = self.text;

dbg!(sql);

let mut executor = TreeSitterQueriesExecutor::new(tree.root_node(), self.text);

executor.add_query_results::<queries::RelationMatch>();

for relation_match in executor.get_iter(stmt_range) {
match relation_match {
QueryResult::Relation(r) => {
let schema_name = r.get_schema(sql);
let table_name = r.get_table(sql);

let current = self.mentioned_relations.get_mut(&schema_name);

match current {
Some(c) => {
c.insert(table_name);
}
None => {
let mut new = HashSet::new();
new.insert(table_name);
self.mentioned_relations.insert(schema_name, new);
}
};
}
};
}
}

pub fn get_ts_node_content(&self, ts_node: tree_sitter::Node<'a>) -> Option<&'a str> {
let source = self.text;
match ts_node.utf8_text(source.as_bytes()) {
Expand Down Expand Up @@ -100,36 +151,38 @@ impl<'a> CompletionContext<'a> {
* We'll therefore adjust the cursor position such that it meets the last node of the AST.
* `select * from use {}` becomes `select * from use{}`.
*/
let current_node_kind = cursor.node().kind();
let current_node = cursor.node();
while cursor.goto_first_child_for_byte(self.position).is_none() && self.position > 0 {
self.position -= 1;
}

self.gather_context_from_node(cursor, current_node_kind);
self.gather_context_from_node(cursor, current_node);
}

fn gather_context_from_node(
&mut self,
mut cursor: tree_sitter::TreeCursor<'a>,
previous_node_kind: &str,
previous_node: tree_sitter::Node<'a>,
) {
let current_node = cursor.node();
let current_node_kind = current_node.kind();

// prevent infinite recursion – this can happen if we only have a PROGRAM node
if current_node_kind == previous_node_kind {
if current_node.kind() == previous_node.kind() {
self.ts_node = Some(current_node);
return;
}

match previous_node_kind {
"statement" => self.wrapping_clause_type = current_node_kind.try_into().ok(),
match previous_node.kind() {
"statement" | "subquery" => {
self.wrapping_clause_type = current_node.kind().try_into().ok();
self.wrapping_statement_range = Some(previous_node.range());
}
"invocation" => self.is_invocation = true,

_ => {}
}

match current_node_kind {
match current_node.kind() {
"object_reference" => {
let txt = self.get_ts_node_content(current_node);
if let Some(txt) = txt {
Expand Down Expand Up @@ -159,7 +212,7 @@ impl<'a> CompletionContext<'a> {
}

cursor.goto_first_child_for_byte(self.position);
self.gather_context_from_node(cursor, current_node_kind);
self.gather_context_from_node(cursor, current_node);
}
}

Expand Down Expand Up @@ -209,7 +262,7 @@ mod tests {
];

for (query, expected_clause) in test_cases {
let (position, text) = get_text_and_position(query.as_str());
let (position, text) = get_text_and_position(query.as_str().into());

let tree = get_tree(text.as_str());

Expand Down Expand Up @@ -242,7 +295,7 @@ mod tests {
];

for (query, expected_schema) in test_cases {
let (position, text) = get_text_and_position(query.as_str());
let (position, text) = get_text_and_position(query.as_str().into());

let tree = get_tree(text.as_str());
let params = crate::CompletionParams {
Expand Down Expand Up @@ -276,7 +329,7 @@ mod tests {
];

for (query, is_invocation) in test_cases {
let (position, text) = get_text_and_position(query.as_str());
let (position, text) = get_text_and_position(query.as_str().into());

let tree = get_tree(text.as_str());
let params = crate::CompletionParams {
Expand All @@ -300,7 +353,7 @@ mod tests {
];

for query in cases {
let (position, text) = get_text_and_position(query.as_str());
let (position, text) = get_text_and_position(query.as_str().into());

let tree = get_tree(text.as_str());

Expand Down Expand Up @@ -328,7 +381,7 @@ mod tests {
fn does_not_fail_on_trailing_whitespace() {
let query = format!("select * from {}", CURSOR_POS);

let (position, text) = get_text_and_position(query.as_str());
let (position, text) = get_text_and_position(query.as_str().into());

let tree = get_tree(text.as_str());

Expand All @@ -354,7 +407,7 @@ mod tests {
fn does_not_fail_with_empty_statements() {
let query = format!("{}", CURSOR_POS);

let (position, text) = get_text_and_position(query.as_str());
let (position, text) = get_text_and_position(query.as_str().into());

let tree = get_tree(text.as_str());

Expand All @@ -379,7 +432,7 @@ mod tests {
// is selecting a certain column name, such as `frozen_account`.
let query = format!("select * fro{}", CURSOR_POS);

let (position, text) = get_text_and_position(query.as_str());
let (position, text) = get_text_and_position(query.as_str().into());

let tree = get_tree(text.as_str());

Expand Down
1 change: 1 addition & 0 deletions crates/pg_completions/src/item.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use serde::{Deserialize, Serialize};
pub enum CompletionItemKind {
Table,
Function,
Column,
}

#[derive(Debug, Serialize, Deserialize)]
Expand Down
114 changes: 114 additions & 0 deletions crates/pg_completions/src/providers/columns.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
use crate::{
builder::CompletionBuilder, context::CompletionContext, relevance::CompletionRelevanceData,
CompletionItem, CompletionItemKind,
};

pub fn complete_columns(ctx: &CompletionContext, builder: &mut CompletionBuilder) {
let available_columns = &ctx.schema_cache.columns;

for col in available_columns {
let item = CompletionItem {
label: col.name.clone(),
score: CompletionRelevanceData::Column(col).get_score(ctx),
description: format!("Table: {}.{}", col.schema_name, col.table_name),
preselected: false,
kind: CompletionItemKind::Column,
};

builder.add_item(item);
}
}

#[cfg(test)]
mod tests {
use crate::{
complete,
test_helper::{get_test_deps, get_test_params, InputQuery, CURSOR_POS},
CompletionItem,
};

struct TestCase {
query: String,
message: &'static str,
label: &'static str,
description: &'static str,
}

impl TestCase {
fn get_input_query(&self) -> InputQuery {
let strs: Vec<&str> = self.query.split_whitespace().collect();
strs.join(" ").as_str().into()
}
}

#[tokio::test]
async fn completes_columns() {
let setup = r#"
create schema private;
create table public.users (
id serial primary key,
name text
);
create table public.audio_books (
id serial primary key,
narrator text
);
create table private.audio_books (
id serial primary key,
narrator_id text
);
"#;

let queries: Vec<TestCase> = vec![
TestCase {
message: "correctly prefers the columns of present tables",
query: format!(r#"select na{} from public.audio_books;"#, CURSOR_POS),
label: "narrator",
description: "Table: public.audio_books",
},
TestCase {
message: "correctly handles nested queries",
query: format!(
r#"
select
*
from (
select id, na{}
from private.audio_books
) as subquery
join public.users u
on u.id = subquery.id;
"#,
CURSOR_POS
),
label: "narrator_id",
description: "Table: private.audio_books",
},
TestCase {
message: "works without a schema",
query: format!(r#"select na{} from users;"#, CURSOR_POS),
label: "name",
description: "Table: public.users",
},
];

for q in queries {
let (tree, cache) = get_test_deps(setup, q.get_input_query()).await;
let params = get_test_params(&tree, &cache, q.get_input_query());
let results = complete(params);

let CompletionItem {
label, description, ..
} = results
.into_iter()
.next()
.expect("Should return at least one completion item");

assert_eq!(label, q.label, "{}", q.message);
assert_eq!(description, q.description, "{}", q.message);
}
}
}
Loading
Loading