Skip to content
23 changes: 12 additions & 11 deletions internal/cmd/shim.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,17 +228,18 @@ func pluginQueryColumn(c *compiler.Column) *plugin.Column {
l = *c.Length
}
out := &plugin.Column{
Name: c.Name,
OriginalName: c.OriginalName,
Comment: c.Comment,
NotNull: c.NotNull,
Unsigned: c.Unsigned,
IsArray: c.IsArray,
ArrayDims: int32(c.ArrayDims),
Length: int32(l),
IsNamedParam: c.IsNamedParam,
IsFuncCall: c.IsFuncCall,
IsSqlcSlice: c.IsSqlcSlice,
Name: c.Name,
OriginalName: c.OriginalName,
Comment: c.Comment,
NotNull: c.NotNull,
Unsigned: c.Unsigned,
IsArray: c.IsArray,
ArrayDims: int32(c.ArrayDims),
Length: int32(l),
IsNamedParam: c.IsNamedParam,
IsFuncCall: c.IsFuncCall,
IsSqlcSlice: c.IsSqlcSlice,
IsSqlcDynamic: c.IsSqlcDynamic,
}

if c.Type != nil {
Expand Down
4 changes: 4 additions & 0 deletions internal/codegen/golang/field.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ func (gf Field) HasSqlcSlice() bool {
return gf.Column.IsSqlcSlice
}

func (gf Field) HasSqlcDynamic() bool {
return gf.Column.IsSqlcDynamic
}

func TagsToString(tags map[string]string) string {
if len(tags) == 0 {
return ""
Expand Down
17 changes: 16 additions & 1 deletion internal/codegen/golang/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"text/template"

"github.com/sqlc-dev/sqlc/internal/codegen/sdk"
"github.com/sqlc-dev/sqlc/internal/config"
"github.com/sqlc-dev/sqlc/internal/metadata"
"github.com/sqlc-dev/sqlc/internal/plugin"
)
Expand Down Expand Up @@ -38,6 +39,7 @@ type tmplCtx struct {
EmitAllEnumValues bool
UsesCopyFrom bool
UsesBatch bool
HasSqlcDynamic bool
BuildTags string
}

Expand Down Expand Up @@ -130,6 +132,13 @@ func generate(req *plugin.CodeGenRequest, options *opts, enums []Enum, structs [
Enums: enums,
Structs: structs,
}
var hasDynamic bool
for _, q := range queries {
if q.Arg.HasSqlcDynamic() {
hasDynamic = true
break
}
}

tctx := tmplCtx{
EmitInterface: options.EmitInterface,
Expand All @@ -148,8 +157,8 @@ func generate(req *plugin.CodeGenRequest, options *opts, enums []Enum, structs [
Package: options.Package,
Enums: enums,
Structs: structs,
SqlcVersion: req.SqlcVersion,
BuildTags: options.BuildTags,
SqlcVersion: req.SqlcVersion,
}

if tctx.UsesCopyFrom && !tctx.SQLDriver.IsPGX() && options.SqlDriver != SQLDriverGoSQLDriverMySQL {
Expand Down Expand Up @@ -180,6 +189,12 @@ func generate(req *plugin.CodeGenRequest, options *opts, enums []Enum, structs [
"emitPreparedQueries": tctx.codegenEmitPreparedQueries,
"queryMethod": tctx.codegenQueryMethod,
"queryRetval": tctx.codegenQueryRetval,
"dollar": func() bool {
return req.Settings.Engine == string(config.EnginePostgreSQL)
},
"hasDynamic": func() bool {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wasn't sure how to add these values without putting them here. So while it does work it is kind of a hack.

return hasDynamic
},
}

tmpl := template.Must(
Expand Down
3 changes: 3 additions & 0 deletions internal/codegen/golang/go_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ func addExtraGoStructTags(tags map[string]string, req *plugin.CodeGenRequest, co

func goType(req *plugin.CodeGenRequest, options *opts, col *plugin.Column) string {
// Check if the column's type has been overridden
if col.IsSqlcDynamic {
return "DynamicSql"
}
for _, oride := range req.Settings.Overrides {
if oride.GoType.TypeName == "" {
continue
Expand Down
6 changes: 5 additions & 1 deletion internal/codegen/golang/imports.go
Original file line number Diff line number Diff line change
Expand Up @@ -288,10 +288,14 @@ func sortedImports(std map[string]struct{}, pkg map[ImportSpec]struct{}) fileImp
func (i *importer) queryImports(filename string) fileImports {
var gq []Query
anyNonCopyFrom := false
useStrings := false
for _, query := range i.Queries {
if usesBatch([]Query{query}) {
continue
}
if query.Arg.HasSqlcDynamic() {
useStrings = true
}
if query.SourceName == filename {
gq = append(gq, query)
if query.Cmd != metadata.CmdCopyFrom {
Expand Down Expand Up @@ -384,7 +388,7 @@ func (i *importer) queryImports(filename string) fileImports {
}

sqlpkg := parseDriver(i.Options.SqlPackage)
if sqlcSliceScan() {
if useStrings || sqlcSliceScan() {
std["strings"] = struct{}{}
}
if sliceScan() && !sqlpkg.IsPGX() {
Expand Down
3 changes: 3 additions & 0 deletions internal/codegen/golang/mysql_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ func mysqlType(req *plugin.CodeGenRequest, col *plugin.Column) string {
columnType := sdk.DataType(col.Type)
notNull := col.NotNull || col.IsArray
unsigned := col.Unsigned
if col.IsSqlcDynamic {
return "DynamicSql"
}

switch columnType {

Expand Down
3 changes: 3 additions & 0 deletions internal/codegen/golang/postgresql_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ func parseIdentifierString(name string) (*plugin.Identifier, error) {
}

func postgresType(req *plugin.CodeGenRequest, options *opts, col *plugin.Column) string {
if col.IsSqlcDynamic {
return "DynamicSql"
}
columnType := sdk.DataType(col.Type)
notNull := col.NotNull || col.IsArray
driver := parseDriver(options.SqlPackage)
Expand Down
32 changes: 30 additions & 2 deletions internal/codegen/golang/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,14 +128,16 @@ func (v QueryValue) Params() string {
}
var out []string
if v.Struct == nil {
if !v.Column.IsSqlcSlice && strings.HasPrefix(v.Typ, "[]") && v.Typ != "[]byte" && !v.SQLDriver.IsPGX() {
if v.Column.IsSqlcDynamic {
} else if !v.Column.IsSqlcSlice && strings.HasPrefix(v.Typ, "[]") && v.Typ != "[]byte" && !v.SQLDriver.IsPGX() {
out = append(out, "pq.Array("+escape(v.Name)+")")
} else {
out = append(out, escape(v.Name))
}
} else {
for _, f := range v.Struct.Fields {
if !f.HasSqlcSlice() && strings.HasPrefix(f.Type, "[]") && f.Type != "[]byte" && !v.SQLDriver.IsPGX() {
if f.HasSqlcDynamic() {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @ovadbar there is no logic inside this condition. May be it is not required?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe I have should have written it as the following. The code is needed

if !f.HasSqlcDynamic() { if !f.HasSqlcSlice() && strings.HasPrefix(f.Type, "[]") && f.Type != "[]byte" && !v.SQLDriver.IsPGX() {	out = append(out, "pq.Array("+escape(v.VariableForField(f))+")") } else {	out = append(out, escape(v.VariableForField(f))) } } 

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if f.HasSqlcDynamic() {
if !f.HasSqlcDynamic() {
if !f.HasSqlcSlice() && strings.HasPrefix(f.Type, "[]") && f.Type != "[]byte" && !v.SQLDriver.IsPGX() {
out = append(out, "pq.Array("+escape(v.VariableForField(f))+")")
} else {
out = append(out, escape(v.VariableForField(f)))
}
}
} else if !f.HasSqlcSlice() && strings.HasPrefix(f.Type, "[]") && f.Type != "[]byte" && !v.SQLDriver.IsPGX() {
out = append(out, "pq.Array("+escape(v.VariableForField(f))+")")
} else {
out = append(out, escape(v.VariableForField(f)))
Expand Down Expand Up @@ -188,6 +190,32 @@ func (v QueryValue) HasSqlcSlices() bool {
}
return false
}
func (v QueryValue) HasSqlcDynamic() bool {
if v.Struct == nil {
if v.Column != nil && v.Column.IsSqlcDynamic {
return true
}
return false
}
for _, v := range v.Struct.Fields {
if v.Column.IsSqlcDynamic {
return true
}
}
return false
}
func (v QueryValue) SqlcDynamic() int {
var count int = 1
if v.Struct == nil {
return 1
}
for _, v := range v.Struct.Fields {
if !v.Column.IsSqlcDynamic {
count++
}
}
return count
}

func (v QueryValue) Scan() string {
var out []string
Expand Down
5 changes: 5 additions & 0 deletions internal/codegen/golang/templates/pgx/dbCode.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ type DBTX interface {
{{- end }}
}

{{- if hasDynamic }}
type DynamicSql interface {
ToSql(int) (string, []interface{})
}
{{- end}}
{{ if .EmitMethodsWithDBArgument}}
func New() *Queries {
return &Queries{}
Expand Down
73 changes: 73 additions & 0 deletions internal/codegen/golang/templates/pgx/queryCode.tmpl
Original file line number Diff line number Diff line change
@@ -1,3 +1,25 @@
{{define "preexec"}}
{{- if .Arg.Struct }}
queryParams := []interface{}{ {{.Arg.Params}} }
{{- $arg := .Arg }}
curNumb := {{ $arg.SqlcDynamic }}
query := {{.ConstantName}}
var replaceText string
var args []interface{}
{{- range .Arg.Struct.Fields }}
{{- if .HasSqlcDynamic }}
replaceText, args = {{$arg.VariableForField .}}.ToSql(curNumb)
curNumb += len(args)
query = strings.ReplaceAll(query, "/*DYNAMIC:{{.Column.Name}}*/$1", replaceText)
queryParams = append(queryParams, args...)
{{- end}}
{{- end}}
{{- else}}
replaceText, queryParams := {{.Arg.Column.Name}}.ToSql(1)
query := strings.ReplaceAll({{.ConstantName}}, "/*DYNAMIC:{{.Arg.Column.Name}}*/$1", replaceText)
{{- end}}
{{- end}}

{{define "queryCodePgx"}}
{{range .GoQueries}}
{{if $.OutputQuery .SourceName}}
Expand Down Expand Up @@ -28,10 +50,20 @@ type {{.Ret.Type}} struct { {{- range .Ret.Struct.Fields}}
{{end -}}
{{- if $.EmitMethodsWithDBArgument -}}
func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) ({{.Ret.DefineType}}, error) {
{{- if .Arg.HasSqlcDynamic }}
{{- template "preexec" .}}
row := db.QueryRow(ctx, query, queryParams...)
{{- else}}
row := db.QueryRow(ctx, {{.ConstantName}}, {{.Arg.Params}})
{{- end}}
{{- else -}}
func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ({{.Ret.DefineType}}, error) {
{{- if .Arg.HasSqlcDynamic }}
{{- template "preexec" .}}
row := q.db.QueryRow(ctx, query, queryParams...)
{{- else}}
row := q.db.QueryRow(ctx, {{.ConstantName}}, {{.Arg.Params}})
{{- end}}
{{- end}}
{{- if or (ne .Arg.Pair .Ret.Pair) (ne .Arg.DefineType .Ret.DefineType) }}
var {{.Ret.Name}} {{.Ret.Type}}
Expand All @@ -46,10 +78,20 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ({{.Ret.De
{{end -}}
{{- if $.EmitMethodsWithDBArgument -}}
func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) ([]{{.Ret.DefineType}}, error) {
{{- if .Arg.HasSqlcDynamic }}
{{- template "preexec" .}}
rows, err := db.Query(ctx, query, queryParams...)
{{- else}}
rows, err := db.Query(ctx, {{.ConstantName}}, {{.Arg.Params}})
{{- end}}
{{- else -}}
func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ([]{{.Ret.DefineType}}, error) {
{{- if .Arg.HasSqlcDynamic }}
{{- template "preexec" .}}
rows, err := q.db.Query(ctx, query, queryParams...)
{{- else}}
rows, err := q.db.Query(ctx, {{.ConstantName}}, {{.Arg.Params}})
{{- end}}
{{- end}}
if err != nil {
return nil, err
Expand Down Expand Up @@ -79,10 +121,20 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ([]{{.Ret.
{{end -}}
{{- if $.EmitMethodsWithDBArgument -}}
func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) error {
{{- if .Arg.HasSqlcDynamic }}
{{- template "preexec" .}}
_, err := db.Exec(ctx, query, queryParams...)
{{- else}}
_, err := db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}})
{{- end}}
{{- else -}}
func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) error {
{{- if .Arg.HasSqlcDynamic }}
{{- template "preexec" .}}
_, err := q.db.Exec(ctx, query, queryParams...)
{{- else}}
_, err := q.db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}})
{{- end}}
{{- end}}
return err
}
Expand All @@ -93,10 +145,20 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) error {
{{end -}}
{{if $.EmitMethodsWithDBArgument -}}
func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) (int64, error) {
{{- if .Arg.HasSqlcDynamic }}
{{- template "preexec" .}}
result, err := db.Exec(ctx, query, queryParams...)
{{- else}}
result, err := db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}})
{{- end}}
{{- else -}}
func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (int64, error) {
{{- if .Arg.HasSqlcDynamic }}
{{- template "preexec" .}}
result, err := q.db.Exec(ctx, query, queryParams...)
{{- else}}
result, err := q.db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}})
{{- end}}
{{- end}}
if err != nil {
return 0, err
Expand All @@ -110,10 +172,20 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (int64, er
{{end -}}
{{- if $.EmitMethodsWithDBArgument -}}
func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) (pgconn.CommandTag, error) {
{{- if .Arg.HasSqlcDynamic }}
{{- template "preexec" .}}
return db.Exec(ctx, query, queryParams...)
{{- else}}
return db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}})
{{- end}}
{{- else -}}
func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (pgconn.CommandTag, error) {
{{- if .Arg.HasSqlcDynamic }}
{{- template "preexec" .}}
return q.db.Exec(ctx, query, queryParams...)
{{- else}}
return q.db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}})
{{- end}}
{{- end}}
}
{{end}}
Expand All @@ -122,3 +194,4 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (pgconn.Co
{{end}}
{{end}}
{{end}}

6 changes: 6 additions & 0 deletions internal/codegen/golang/templates/stdlib/dbCode.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@ type DBTX interface {
QueryRowContext(context.Context, string, ...interface{}) *sql.Row
}

{{- if hasDynamic }}
type DynamicSql interface {
ToSql({{ if dollar}}int{{ end }}) (string, []interface{})
}
{{- end}}

{{ if .EmitMethodsWithDBArgument}}
func New() *Queries {
return &Queries{}
Expand Down
Loading