Skip to content
Prev Previous commit
Next Next commit
hmmm
  • Loading branch information
juleswritescode committed Apr 4, 2025
commit 6c4812b3ff23832affd45a0ff968e847b54039dc
5 changes: 4 additions & 1 deletion crates/pgt_completions/src/complete.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@ pub struct CompletionParams<'a> {
pub tree: Option<&'a tree_sitter::Tree>,
}

#[tracing::instrument(level = "debug")]
#[tracing::instrument(level = "debug", skip_all, fields(
text = params.text,
position = params.position.to_string()
))]
pub fn complete(params: CompletionParams) -> Vec<CompletionItem> {
let ctx = CompletionContext::new(&params);

Expand Down
65 changes: 51 additions & 14 deletions crates/pgt_completions/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,29 @@ impl TryFrom<String> for ClauseType {
}

pub(crate) struct CompletionContext<'a> {
pub ts_node: Option<tree_sitter::Node<'a>>,
pub node_under_cursor: Option<tree_sitter::Node<'a>>,
pub previous_node: Option<tree_sitter::Node<'a>>,

pub tree: Option<&'a tree_sitter::Tree>,
pub text: &'a str,
pub schema_cache: &'a SchemaCache,
pub position: usize,

/// If the cursor of the user is offset to the right of the statement,
/// we'll have to move it back to the last node, otherwise, tree-sitter will break.
/// However, knowing that the user is typing on the "next" node lets us prioritize different completion results.
/// We consider an offset of up to two characters as valid.
///
/// Example:
///
/// ```
/// select * from {}
/// ```
///
/// We'll adjust the cursor position so it lies on the "from" token – but we're looking
/// for table completions.
pub cursor_offset_from_end: bool,

pub schema_name: Option<String>,
pub wrapping_clause_type: Option<ClauseType>,
pub is_invocation: bool,
Expand All @@ -70,7 +87,9 @@ impl<'a> CompletionContext<'a> {
text: &params.text,
schema_cache: params.schema,
position: usize::from(params.position),
ts_node: None,
cursor_offset_from_end: false,
previous_node: None,
node_under_cursor: None,
schema_name: None,
wrapping_clause_type: None,
wrapping_statement_range: None,
Expand All @@ -81,8 +100,6 @@ impl<'a> CompletionContext<'a> {
ctx.gather_tree_context();
ctx.gather_info_from_ts_queries();

println!("Here's my node: {:?}", ctx.ts_node.unwrap());

ctx
}

Expand Down Expand Up @@ -147,30 +164,34 @@ impl<'a> CompletionContext<'a> {
* `select * from use {}` becomes `select * from use{}`.
*/
let current_node = cursor.node();
let position_cache = self.position.clone();
while cursor.goto_first_child_for_byte(self.position).is_none() && self.position > 0 {
self.position -= 1;
}

let cursor_offset = position_cache - self.position;
self.cursor_offset_from_end = cursor_offset > 0 && cursor_offset <= 2;

self.gather_context_from_node(cursor, current_node);
}

fn gather_context_from_node(
&mut self,
mut cursor: tree_sitter::TreeCursor<'a>,
previous_node: tree_sitter::Node<'a>,
parent_node: tree_sitter::Node<'a>,
) {
let current_node = cursor.node();

// prevent infinite recursion – this can happen if we only have a PROGRAM node
if current_node.kind() == previous_node.kind() {
self.ts_node = Some(current_node);
if current_node.kind() == parent_node.kind() {
self.node_under_cursor = Some(current_node);
return;
}

match previous_node.kind() {
match parent_node.kind() {
"statement" | "subquery" => {
self.wrapping_clause_type = current_node.kind().try_into().ok();
self.wrapping_statement_range = Some(previous_node.range());
self.wrapping_statement_range = Some(parent_node.range());
}
"invocation" => self.is_invocation = true,

Expand Down Expand Up @@ -202,7 +223,23 @@ impl<'a> CompletionContext<'a> {

// We have arrived at the leaf node
if current_node.child_count() == 0 {
self.ts_node = Some(current_node);
if self.cursor_offset_from_end {
self.node_under_cursor = None;
self.previous_node = Some(current_node);
} else {
// for the previous node, either select the previous sibling,
// or collect the parent's previous sibling's last child.
let previous = match current_node.prev_sibling() {
Some(n) => Some(n),
None => {
let sib_of_parent = parent_node.prev_sibling();
sib_of_parent.and_then(|p| p.children(&mut cursor).last())
}
};
self.node_under_cursor = Some(current_node);
self.previous_node = previous;
}

return;
}

Expand Down Expand Up @@ -361,7 +398,7 @@ mod tests {

let ctx = CompletionContext::new(&params);

let node = ctx.ts_node.unwrap();
let node = ctx.node_under_cursor.unwrap();

assert_eq!(ctx.get_ts_node_content(node), Some("select"));

Expand Down Expand Up @@ -389,7 +426,7 @@ mod tests {

let ctx = CompletionContext::new(&params);

let node = ctx.ts_node.unwrap();
let node = ctx.node_under_cursor.unwrap();

assert_eq!(ctx.get_ts_node_content(node), Some("from"));
assert_eq!(
Expand All @@ -415,7 +452,7 @@ mod tests {

let ctx = CompletionContext::new(&params);

let node = ctx.ts_node.unwrap();
let node = ctx.node_under_cursor.unwrap();

assert_eq!(ctx.get_ts_node_content(node), Some(""));
assert_eq!(ctx.wrapping_clause_type, None);
Expand All @@ -440,7 +477,7 @@ mod tests {

let ctx = CompletionContext::new(&params);

let node = ctx.ts_node.unwrap();
let node = ctx.node_under_cursor.unwrap();

assert_eq!(ctx.get_ts_node_content(node), Some("fro"));
assert_eq!(ctx.wrapping_clause_type, Some(ClauseType::Select));
Expand Down
5 changes: 4 additions & 1 deletion crates/pgt_completions/src/relevance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,10 @@ impl CompletionRelevance<'_> {
}

fn check_matches_query_input(&mut self, ctx: &CompletionContext) {
let node = ctx.ts_node.unwrap();
let node = match ctx.node_under_cursor {
Some(node) => node,
None => return,
};

let content = match ctx.get_ts_node_content(node) {
Some(c) => c,
Expand Down
20 changes: 13 additions & 7 deletions crates/pgt_workspace/src/workspace/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use pgt_analyse::{AnalyserOptions, AnalysisFilter};
use pgt_analyser::{Analyser, AnalyserConfig, AnalyserContext};
use pgt_diagnostics::{Diagnostic, DiagnosticExt, Severity, serde::Diagnostic as SDiagnostic};
use pgt_fs::{ConfigName, PgTPath};
use pgt_text_size::{TextRange, TextSize};
use pgt_typecheck::TypecheckParams;
use schema_cache_manager::SchemaCacheManager;
use sqlx::Executor;
Expand Down Expand Up @@ -535,13 +536,18 @@ impl Workspace for WorkspaceServer {
.get(&params.path)
.ok_or(WorkspaceError::not_found())?;

let (statement, stmt_range, text) = match doc
.iter_statements_with_text_and_range()
.find(|(_, r, _)| r.contains(params.position))
{
Some(s) => s,
None => return Ok(CompletionsResult::default()),
};
let (statement, stmt_range, text) =
match doc.iter_statements_with_text_and_range().find(|(_, r, _)| {
let expanded_range = TextRange::new(
r.start(),
r.end().checked_add(TextSize::new(2)).unwrap_or(r.end()),
);

expanded_range.contains(params.position)
}) {
Some(s) => s,
None => return Ok(CompletionsResult::default()),
};

// `offset` is the position in the document,
// but we need the position within the *statement*.
Expand Down
Loading