Skip to content
Prev Previous commit
Next Next commit
Some changes
  • Loading branch information
Baroukh Ovadia committed Jun 28, 2021
commit 6f40bc10116feddd41fdb75e6a1fc27cf6ea8183
3 changes: 3 additions & 0 deletions internal/compiler/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query,
if o.Debug.DumpAST {
debug.Dump(stmt)
}
if err := validate.ParamStyle(stmt); err != nil {
return nil, err
}
lastNumber, err := validate.ParamRef(stmt)
if err != nil {
return nil, err
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 8 additions & 8 deletions internal/endtoend/testdata/mix_param_types/mysql/go/test.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions internal/endtoend/testdata/mix_param_types/mysql/test.sql
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ CREATE TABLE bar (
);

-- name: CountOne :one
SELECT count(1) FROM bar WHERE id = sqlc.arg(id) AND name <> $1;
SELECT count(1) FROM bar WHERE id = sqlc.arg(id) AND name <> ?;

-- name: CountTwo :one
SELECT count(1) FROM bar WHERE id = $1 AND name <> sqlc.arg(name);
SELECT count(1) FROM bar WHERE id = ? AND name <> sqlc.arg(name);

-- name: CountThree :one
SELECT count(1) FROM bar WHERE id > $2 AND phone <> sqlc.arg(phone) AND name <> $1;
SELECT count(1) FROM bar WHERE id > ? AND phone <> sqlc.arg(phone) AND name <> ?;
49 changes: 49 additions & 0 deletions internal/sql/validate/param_style.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package validate

import (
"github.com/kyleconroy/sqlc/internal/sql/ast"
"github.com/kyleconroy/sqlc/internal/sql/astutils"
"github.com/kyleconroy/sqlc/internal/sql/named"
"github.com/kyleconroy/sqlc/internal/sql/sqlerr"
)

// A query can use one (and only one) of the following formats:
// - positional parameters $1
// - named parameter operator @param
// - named parameter function calls sqlc.arg(param)
func ParamStyle(n ast.Node) error {
namedFunc := astutils.Search(n, named.IsParamFunc)
for _, f := range namedFunc.Items {
fc, ok := f.(*ast.FuncCall)
if ok {
/*
if len(fc.Args.Items) != 1 {
return &sqlerr.Error{
Code: "", // TODO: Pick a new error code
Message: "Wrong number of arguments to sqlc.arg()",
}
}
*/
switch fc.Args.Items[0].(type) {
case *ast.FuncCall:
return &sqlerr.Error{
Code: "", // TODO: Pick a new error code
Message: "expected parameter to sqlc.arg to be string or reference; got *ast.FuncCall",
}
case *ast.ParamRef:
return &sqlerr.Error{
Code: "", // TODO: Pick a new error code
Message: "query mixes positional parameters ($1) and named parameters (sqlc.arg or @arg)",
}
case *ast.A_Const, *ast.ColumnRef:
default:
return &sqlerr.Error{
Code: "", // TODO: Pick a new error code
Message: "Invalid argument to sqlc.arg()",
}

}
}
}
return nil
}