Skip to content

Commit 50248d5

Browse files
authored
Merge pull request #72 from supabase/feat/create-aggregate
2 parents 941a7b6 + 0753912 commit 50248d5

File tree

4 files changed

+65
-4
lines changed

4 files changed

+65
-4
lines changed

crates/codegen/src/get_node_properties.rs

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ pub fn get_node_properties_mod(proto_file: &ProtoFile) -> proc_macro2::TokenStre
127127
}
128128
}
129129

130-
pub fn get_node_properties(node: &NodeEnum) -> Vec<TokenProperty> {
130+
pub fn get_node_properties(node: &NodeEnum, parent: Option<&NodeEnum>) -> Vec<TokenProperty> {
131131
let mut tokens: Vec<TokenProperty> = Vec::new();
132132

133133
match node {
@@ -515,6 +515,56 @@ fn custom_handlers(node: &Node) -> TokenStream {
515515
tokens.push(TokenProperty::from(Token::As));
516516
}
517517
},
518+
"List" => quote! {
519+
if parent.is_some() {
520+
// if parent is `DefineStmt`, we need to check whether an ORDER BY needs to be added
521+
if let NodeEnum::DefineStmt(define_stmt) = parent.unwrap() {
522+
// there *seems* to be an integer node in the last position of the DefineStmt args that
523+
// defines whether the list contains an order by statement
524+
let integer = define_stmt.args.last()
525+
.and_then(|node| node.node.as_ref())
526+
.and_then(|node| if let NodeEnum::Integer(n) = node { Some(n.ival) } else { None });
527+
if integer.is_none() {
528+
panic!("DefineStmt of type ObjectAggregate has no integer node in last position of args");
529+
}
530+
// if the integer is 1, then there is an order by statement
531+
// we add it to the `List` node because that seems to make most sense based off the grammar definition
532+
// ref: https://github.com/postgres/postgres/blob/REL_15_STABLE/src/backend/parser/gram.y#L8355
533+
// ```
534+
// aggr_args:
535+
// | '(' aggr_args_list ORDER BY aggr_args_list ')'
536+
// ```
537+
if integer.unwrap() == 1 {
538+
tokens.push(TokenProperty::from(Token::Order));
539+
tokens.push(TokenProperty::from(Token::By));
540+
}
541+
}
542+
}
543+
},
544+
"DefineStmt" => quote! {
545+
tokens.push(TokenProperty::from(Token::Create));
546+
if n.replace {
547+
tokens.push(TokenProperty::from(Token::Or));
548+
tokens.push(TokenProperty::from(Token::Replace));
549+
}
550+
match n.kind() {
551+
protobuf::ObjectType::ObjectAggregate => {
552+
tokens.push(TokenProperty::from(Token::Aggregate));
553+
554+
// n.args is always an array with two nodes
555+
assert_eq!(n.args.len(), 2, "DefineStmt of type ObjectAggregate does not have exactly 2 args");
556+
// the first is either a List or a Node { node: None }
557+
558+
if let Some(node) = &n.args.first() {
559+
if node.node.is_none() {
560+
// if first element is a Node { node: None }, then it's "*"
561+
tokens.push(TokenProperty::from(Token::Ascii42));
562+
} }
563+
// if its a list, we handle it in the handler for `List`
564+
},
565+
_ => panic!("Unknown DefineStmt {:#?}", n.kind()),
566+
}
567+
},
518568
"CreateSchemaStmt" => quote! {
519569
tokens.push(TokenProperty::from(Token::Create));
520570
tokens.push(TokenProperty::from(Token::Schema));

crates/codegen/src/get_nodes.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ pub fn get_nodes_mod(proto_file: &ProtoFile) -> proc_macro2::TokenStream {
2525
let root_node_idx = g.add_node(Node {
2626
kind: SyntaxKind::from(node),
2727
depth: at_depth,
28-
properties: get_node_properties(node),
28+
properties: get_node_properties(node, None),
2929
location: get_location(node),
3030
});
3131

@@ -45,12 +45,12 @@ pub fn get_nodes_mod(proto_file: &ProtoFile) -> proc_macro2::TokenStream {
4545
NodeEnum::BitString(n) => true,
4646
_ => false
4747
} {
48-
g[parent_idx].properties.extend(get_node_properties(&c));
48+
g[parent_idx].properties.extend(get_node_properties(&c, Some(&node)));
4949
} else {
5050
let node_idx = g.add_node(Node {
5151
kind: SyntaxKind::from(&c),
5252
depth: current_depth,
53-
properties: get_node_properties(&c),
53+
properties: get_node_properties(&c, Some(&node)),
5454
location: get_location(&c),
5555
});
5656
g.add_edge(parent_idx, node_idx, ());

crates/parser/src/parse/libpg_query_node.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ impl<'p> LibpgQueryNodeParser<'p> {
5252
) -> LibpgQueryNodeParser<'p> {
5353
let current_depth = parser.depth.clone();
5454
debug!("Parsing node {:#?}", node);
55+
println!("Parsing node {:#?}", node);
5556
Self {
5657
parser,
5758
token_range,
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
CREATE AGGREGATE aggregate1 (int4) (sfunc = sfunc1, stype = stype1);
2+
CREATE AGGREGATE aggregate1 (int4, bool) (sfunc = sfunc1, stype = stype1);
3+
CREATE AGGREGATE aggregate1 (*) (sfunc = sfunc1, stype = stype1);
4+
CREATE AGGREGATE aggregate1 (int4) (sfunc = sfunc1, stype = stype1, finalfunc_extra, mfinalfuncextra);
5+
CREATE AGGREGATE aggregate1 (int4) (sfunc = sfunc1, stype = stype1, finalfunc_modify = read_only, parallel = restricted);
6+
CREATE AGGREGATE percentile_disc (float8 ORDER BY anyelement) (sfunc = ordered_set_transition, stype = internal, finalfunc = percentile_disc_final, finalfunc_extra);
7+
CREATE AGGREGATE custom_aggregate (float8 ORDER BY column1, column2) (sfunc = sfunc1, stype = stype1);
8+
9+
10+

0 commit comments

Comments
 (0)