Skip to content
10 changes: 10 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ pg_schema_cache = { path = "./crates/pg_schema_cache", version = "0.0.
pg_statement_splitter = { path = "./crates/pg_statement_splitter", version = "0.0.0" }
pg_syntax = { path = "./crates/pg_syntax", version = "0.0.0" }
pg_text_edit = { path = "./crates/pg_text_edit", version = "0.0.0" }
pg_treesitter_queries = { path = "./crates/pg_treesitter_queries", version = "0.0.0" }
pg_type_resolver = { path = "./crates/pg_type_resolver", version = "0.0.0" }
pg_typecheck = { path = "./crates/pg_typecheck", version = "0.0.0" }
pg_workspace = { path = "./crates/pg_workspace", version = "0.0.0" }
Expand Down
1 change: 1 addition & 0 deletions crates/pg_completions/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ serde_json = { workspace = true }
pg_schema_cache.workspace = true
tree-sitter.workspace = true
tree_sitter_sql.workspace = true
pg_treesitter_queries.workspace = true

sqlx.workspace = true

Expand Down
5 changes: 3 additions & 2 deletions crates/pg_completions/src/complete.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::{
builder::CompletionBuilder,
context::CompletionContext,
item::CompletionItem,
providers::{complete_functions, complete_tables},
providers::{complete_columns, complete_functions, complete_tables},
};

pub const LIMIT: usize = 50;
Expand All @@ -31,13 +31,14 @@ impl IntoIterator for CompletionResult {
}
}

pub fn complete(params: CompletionParams) -> CompletionResult {
pub fn complete<'a>(params: CompletionParams<'a>) -> CompletionResult {
let ctx = CompletionContext::new(&params);

let mut builder = CompletionBuilder::new();

complete_tables(&ctx, &mut builder);
complete_functions(&ctx, &mut builder);
complete_columns(&ctx, &mut builder);

builder.finish()
}
74 changes: 64 additions & 10 deletions crates/pg_completions/src/context.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,14 @@
use std::{
collections::{HashMap, HashSet},
hash::Hash,
ops::Range,
};

use pg_schema_cache::SchemaCache;
use pg_treesitter_queries::{
queries::{self, QueryResult},
TreeSitterQueriesExecutor,
};

use crate::CompletionParams;

Expand Down Expand Up @@ -52,10 +62,13 @@ pub(crate) struct CompletionContext<'a> {
pub schema_name: Option<String>,
pub wrapping_clause_type: Option<ClauseType>,
pub is_invocation: bool,
pub wrapping_statement_range: Option<Range<usize>>,

pub mentioned_relations: HashMap<Option<String>, HashSet<String>>,
}

impl<'a> CompletionContext<'a> {
pub fn new(params: &'a CompletionParams) -> Self {
pub fn new(params: &'a CompletionParams<'a>) -> Self {
let mut ctx = Self {
tree: params.tree,
text: &params.text,
Expand All @@ -65,14 +78,53 @@ impl<'a> CompletionContext<'a> {
ts_node: None,
schema_name: None,
wrapping_clause_type: None,
wrapping_statement_range: None,
is_invocation: false,
mentioned_relations: HashMap::new(),
};

ctx.gather_tree_context();
ctx.gather_info_from_ts_queries();

ctx
}

fn gather_info_from_ts_queries(&mut self) {
let tree = match self.tree.as_ref() {
None => return,
Some(t) => t,
};

let stmt_range = self.wrapping_statement_range.as_ref();
let sql = self.text;

let mut executor = TreeSitterQueriesExecutor::new(tree.root_node(), self.text);

executor.add_query_results::<queries::RelationMatch>();

for relation_match in executor.get_iter(stmt_range) {
match relation_match {
QueryResult::Relation(r) => {
let schema_name = r.get_schema(sql);
let table_name = r.get_table(sql);

let current = self.mentioned_relations.get_mut(&schema_name);

match current {
Some(c) => {
c.insert(table_name);
}
None => {
let mut new = HashSet::new();
new.insert(table_name);
self.mentioned_relations.insert(schema_name, new);
}
};
}
};
}
}

pub fn get_ts_node_content(&self, ts_node: tree_sitter::Node<'a>) -> Option<&'a str> {
let source = self.text;
match ts_node.utf8_text(source.as_bytes()) {
Expand Down Expand Up @@ -100,36 +152,38 @@ impl<'a> CompletionContext<'a> {
* We'll therefore adjust the cursor position such that it meets the last node of the AST.
* `select * from use {}` becomes `select * from use{}`.
*/
let current_node_kind = cursor.node().kind();
let current_node = cursor.node();
while cursor.goto_first_child_for_byte(self.position).is_none() && self.position > 0 {
self.position -= 1;
}

self.gather_context_from_node(cursor, current_node_kind);
self.gather_context_from_node(cursor, current_node);
}

fn gather_context_from_node(
&mut self,
mut cursor: tree_sitter::TreeCursor<'a>,
previous_node_kind: &str,
previous_node: tree_sitter::Node<'a>,
) {
let current_node = cursor.node();
let current_node_kind = current_node.kind();

// prevent infinite recursion – this can happen if we only have a PROGRAM node
if current_node_kind == previous_node_kind {
if current_node.kind() == previous_node.kind() {
self.ts_node = Some(current_node);
return;
}

match previous_node_kind {
"statement" => self.wrapping_clause_type = current_node_kind.try_into().ok(),
match previous_node.kind() {
"statement" => {
self.wrapping_clause_type = current_node.kind().try_into().ok();
self.wrapping_statement_range = Some(previous_node.byte_range());
}
"invocation" => self.is_invocation = true,

_ => {}
}

match current_node_kind {
match current_node.kind() {
"object_reference" => {
let txt = self.get_ts_node_content(current_node);
if let Some(txt) = txt {
Expand Down Expand Up @@ -159,7 +213,7 @@ impl<'a> CompletionContext<'a> {
}

cursor.goto_first_child_for_byte(self.position);
self.gather_context_from_node(cursor, current_node_kind);
self.gather_context_from_node(cursor, current_node);
}
}

Expand Down
20 changes: 20 additions & 0 deletions crates/pg_completions/src/providers/columns.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
use crate::{
builder::CompletionBuilder, context::CompletionContext, relevance::CompletionRelevanceData,
CompletionItem, CompletionItemKind,
};

pub fn complete_columns(ctx: &CompletionContext, builder: &mut CompletionBuilder) {
let available_columns = &ctx.schema_cache.columns;

for col in available_columns {
let item = CompletionItem {
label: col.name.clone(),
score: CompletionRelevanceData::Column(col).get_score(ctx),
description: format!("Table: {}.{}", col.schema_name, col.table_name),
preselected: false,
kind: CompletionItemKind::Function,
};

builder.add_item(item);
}
}
2 changes: 2 additions & 0 deletions crates/pg_completions/src/providers/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
mod columns;
mod functions;
mod tables;

pub use columns::*;
pub use functions::*;
pub use tables::*;
57 changes: 53 additions & 4 deletions crates/pg_completions/src/relevance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::context::{ClauseType, CompletionContext};
pub(crate) enum CompletionRelevanceData<'a> {
Table(&'a pg_schema_cache::Table),
Function(&'a pg_schema_cache::Function),
Column(&'a pg_schema_cache::Column),
}

impl<'a> CompletionRelevanceData<'a> {
Expand Down Expand Up @@ -34,6 +35,7 @@ impl<'a> CompletionRelevance<'a> {
self.check_if_catalog(ctx);
self.check_is_invocation(ctx);
self.check_matching_clause_type(ctx);
self.check_relations_in_stmt(ctx);

self.score
}
Expand All @@ -49,6 +51,7 @@ impl<'a> CompletionRelevance<'a> {
let name = match self.data {
CompletionRelevanceData::Function(f) => f.name.as_str(),
CompletionRelevanceData::Table(t) => t.name.as_str(),
CompletionRelevanceData::Column(c) => c.name.as_str(),
};

if name.starts_with(content) {
Expand Down Expand Up @@ -79,6 +82,11 @@ impl<'a> CompletionRelevance<'a> {
ClauseType::From => 0,
_ => -50,
},
CompletionRelevanceData::Column(_) => match clause_type {
ClauseType::Select => 15,
ClauseType::Where => 15,
_ => -15,
},
}
}

Expand Down Expand Up @@ -107,10 +115,7 @@ impl<'a> CompletionRelevance<'a> {
Some(n) => n,
};

let data_schema = match self.data {
CompletionRelevanceData::Function(f) => f.schema.as_str(),
CompletionRelevanceData::Table(t) => t.schema.as_str(),
};
let data_schema = self.get_schema_name();

if schema_name == data_schema {
self.score += 25;
Expand All @@ -119,11 +124,55 @@ impl<'a> CompletionRelevance<'a> {
}
}

fn get_schema_name(&self) -> &str {
match self.data {
CompletionRelevanceData::Function(f) => f.schema.as_str(),
CompletionRelevanceData::Table(t) => t.schema.as_str(),
CompletionRelevanceData::Column(c) => c.schema_name.as_str(),
}
}

fn get_table_name(&self) -> Option<&str> {
match self.data {
CompletionRelevanceData::Column(c) => Some(c.table_name.as_str()),
CompletionRelevanceData::Table(t) => Some(t.name.as_str()),
_ => None,
}
}

fn check_if_catalog(&mut self, ctx: &CompletionContext) {
if ctx.schema_name.as_ref().is_some_and(|n| n == "pg_catalog") {
return;
}

self.score -= 5; // unlikely that the user wants schema data
}

fn check_relations_in_stmt(&mut self, ctx: &CompletionContext) {
match self.data {
CompletionRelevanceData::Table(_) => return,
CompletionRelevanceData::Function(_) => return,
_ => {}
}

let schema = self.get_schema_name().to_string();
let table_name = match self.get_table_name() {
Some(t) => t,
None => return,
};

if ctx
.mentioned_relations
.get(&Some(schema.to_string()))
.is_some_and(|tables| tables.contains(table_name))
{
self.score += 45;
} else if ctx
.mentioned_relations
.get(&None)
.is_some_and(|tables| tables.contains(table_name))
{
self.score += 30;
}
}
}
1 change: 1 addition & 0 deletions crates/pg_schema_cache/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ mod tables;
mod types;
mod versions;

pub use columns::*;
pub use functions::{Behavior, Function, FunctionArg, FunctionArgs};
pub use schema_cache::SchemaCache;
pub use tables::{ReplicaIdentity, Table};
4 changes: 4 additions & 0 deletions crates/pg_test_utils/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ version = "0.0.0"
name = "tree_print"
path = "src/bin/tree_print.rs"

[[bin]]
name = "query_debug"
path = "src/bin/tree_query_debug.rs"

[dependencies]
anyhow = "1.0.81"
clap = { version = "4.5.23", features = ["derive"] }
Expand Down
Loading
Loading