Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
feat: clause_type enum
  • Loading branch information
juleswritescode committed Dec 14, 2024
commit 9415a5dd46bb6571477bffd676190aa15c87ec62
7 changes: 5 additions & 2 deletions crates/pg_completions/src/complete.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
use text_size::TextSize;

use crate::{
builder::CompletionBuilder, context::CompletionContext, item::CompletionItem,
providers::complete_tables,
builder::CompletionBuilder,
context::CompletionContext,
item::CompletionItem,
providers::{complete_functions, complete_tables},
};

pub const LIMIT: usize = 50;
Expand Down Expand Up @@ -34,6 +36,7 @@ pub fn complete(params: CompletionParams) -> CompletionResult {
let mut builder = CompletionBuilder::new();

complete_tables(&ctx, &mut builder);
complete_functions(&ctx, &mut builder);

builder.finish()
}
Expand Down
32 changes: 28 additions & 4 deletions crates/pg_completions/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,30 @@ use pg_schema_cache::SchemaCache;

use crate::CompletionParams;

#[derive(Debug, PartialEq, Eq)]
pub enum ClauseType {
Select,
Where,
From,
}

impl From<&str> for ClauseType {
fn from(value: &str) -> Self {
match value {
"select" => Self::Select,
"where" => Self::Where,
"from" => Self::From,
_ => panic!("Unimplemented ClauseType: {}", value),
}
}
}

impl From<String> for ClauseType {
fn from(value: String) -> Self {
ClauseType::from(value.as_str())
}
}

pub(crate) struct CompletionContext<'a> {
pub ts_node: Option<tree_sitter::Node<'a>>,
pub tree: Option<&'a tree_sitter::Tree>,
Expand All @@ -10,7 +34,7 @@ pub(crate) struct CompletionContext<'a> {
pub position: usize,

pub schema_name: Option<String>,
pub wrapping_clause_type: Option<String>,
pub wrapping_clause_type: Option<ClauseType>,
pub is_invocation: bool,
}

Expand Down Expand Up @@ -65,7 +89,7 @@ impl<'a> CompletionContext<'a> {
let current_node_kind = current_node.kind();

match previous_node_kind {
"statement" => self.wrapping_clause_type = Some(current_node_kind.to_string()),
"statement" => self.wrapping_clause_type = Some(current_node_kind.into()),
"invocation" => self.is_invocation = true,

_ => {}
Expand All @@ -84,7 +108,7 @@ impl<'a> CompletionContext<'a> {

// in Treesitter, the Where clause is nested inside other clauses
"where" => {
self.wrapping_clause_type = Some("where".to_string());
self.wrapping_clause_type = Some("where".into());
}

_ => {}
Expand Down Expand Up @@ -156,7 +180,7 @@ mod tests {

let ctx = CompletionContext::new(&params);

assert_eq!(ctx.wrapping_clause_type, Some(expected_clause.to_string()));
assert_eq!(ctx.wrapping_clause_type, Some(expected_clause.into()));
}
}

Expand Down
2 changes: 1 addition & 1 deletion crates/pg_completions/src/item.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#[derive(Debug)]
#[derive(Debug, PartialEq, Eq)]
pub enum CompletionItemKind {
Table,
Function,
Expand Down
81 changes: 71 additions & 10 deletions crates/pg_completions/src/providers/functions.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use pg_schema_cache::Function;

use crate::{
builder::CompletionBuilder, context::CompletionContext, relevance::CompletionRelevanceData,
CompletionItem, CompletionItemKind,
Expand Down Expand Up @@ -27,10 +25,9 @@ pub fn complete_functions(ctx: &CompletionContext, builder: &mut CompletionBuild
#[cfg(test)]
mod tests {
use crate::{
context::CompletionContext,
providers::complete_functions,
complete,
test_helper::{get_test_deps, get_test_params, CURSOR_POS},
CompletionItem,
CompletionItem, CompletionItemKind,
};

#[tokio::test]
Expand All @@ -49,19 +46,83 @@ mod tests {

let query = format!("select coo{}", CURSOR_POS);

let (tree, cache, mut builder) = get_test_deps(setup, &query).await;
let (tree, cache) = get_test_deps(setup, &query).await;
let params = get_test_params(&tree, &cache, &query);
let ctx = CompletionContext::new(&params);
let results = complete(params);

complete_functions(&ctx, &mut builder);
let CompletionItem { label, .. } = results
.into_iter()
.next()
.expect("Should return at least one completion item");

let results = builder.finish();
assert_eq!(label, "cool");
}

let CompletionItem { label, .. } = results
#[tokio::test]
async fn prefers_fn_if_invocation() {
let setup = r#"
create table coos (
id serial primary key,
name text
);

create or replace function cool()
returns trigger
language plpgsql
security invoker
as $$
begin
raise exception 'dont matter';
end;
$$;
"#;

let query = format!(r#"select * from coo{}()"#, CURSOR_POS);

let (tree, cache) = get_test_deps(setup, &query).await;
let params = get_test_params(&tree, &cache, &query);
let results = complete(params);

let CompletionItem { label, kind, .. } = results
.into_iter()
.next()
.expect("Should return at least one completion item");

assert_eq!(label, "cool");
assert_eq!(kind, CompletionItemKind::Function);
}

#[tokio::test]
async fn prefers_fn_in_select_clause() {
let setup = r#"
create table coos (
id serial primary key,
name text
);

create or replace function cool()
returns trigger
language plpgsql
security invoker
as $$
begin
raise exception 'dont matter';
end;
$$;
"#;

let query = format!(r#"select coo{}"#, CURSOR_POS);

let (tree, cache) = get_test_deps(setup, &query).await;
let params = get_test_params(&tree, &cache, &query);
let results = complete(params);

let CompletionItem { label, kind, .. } = results
.into_iter()
.next()
.expect("Should return at least one completion item");

assert_eq!(label, "cool");
assert_eq!(kind, CompletionItemKind::Function);
}
}
43 changes: 43 additions & 0 deletions crates/pg_completions/src/providers/tables.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,46 @@ pub fn complete_tables(ctx: &CompletionContext, builder: &mut CompletionBuilder)
builder.add_item(item);
}
}

mod tests {
use crate::{
complete,
test_helper::{get_test_deps, get_test_params, CURSOR_POS},
CompletionItem, CompletionItemKind,
};

#[tokio::test]
async fn prefers_table_in_from_clause() {
let setup = r#"
create table coos (
id serial primary key,
name text
);

create or replace function cool()
returns trigger
language plpgsql
security invoker
as $$
begin
raise exception 'dont matter';
end;
$$;
"#;

let query = format!(r#"select * from coo{}"#, CURSOR_POS);

let (tree, cache) = get_test_deps(setup, &query).await;
let params = get_test_params(&tree, &cache, &query);

let results = complete(params);

let CompletionItem { label, kind, .. } = results
.into_iter()
.next()
.expect("Should return at least one completion item");

assert_eq!(label, "coos");
assert_eq!(kind, CompletionItemKind::Table);
}
}
11 changes: 3 additions & 8 deletions crates/pg_completions/src/test_helper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,14 @@ use pg_schema_cache::SchemaCache;
use pg_test_utils::test_database::get_new_test_db;
use sqlx::Executor;

use crate::{builder::CompletionBuilder, CompletionParams};
use crate::CompletionParams;

pub static CURSOR_POS: &str = "XXX";

pub(crate) async fn get_test_deps(
setup: &str,
input: &str,
) -> (
tree_sitter::Tree,
pg_schema_cache::SchemaCache,
CompletionBuilder,
) {
) -> (tree_sitter::Tree, pg_schema_cache::SchemaCache) {
let test_db = get_new_test_db().await;

test_db
Expand All @@ -29,9 +25,8 @@ pub(crate) async fn get_test_deps(
.expect("Error loading sql language");

let tree = parser.parse(input, None).unwrap();
let builder = CompletionBuilder::new();

(tree, schema_cache, builder)
(tree, schema_cache)
}

pub(crate) fn get_test_params<'a>(
Expand Down
Loading