Skip to content
135 changes: 135 additions & 0 deletions server/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,8 @@ func (h *Handler) doQuery(
r, err = resultForEmptyIter(sqlCtx, rowIter, resultFields)
} else if analyzer.FlagIsSet(qFlags, sql.QFlagMax1Row) {
r, err = resultForMax1RowIter(sqlCtx, schema, rowIter, resultFields, buf)
} else if ri2, ok := rowIter.(sql.RowIter2); ok && ri2.IsRowIter2(sqlCtx) {
r, processedAtLeastOneBatch, err = h.resultForDefaultIter2(sqlCtx, c, ri2, resultFields, callback, more)
} else {
r, processedAtLeastOneBatch, err = h.resultForDefaultIter(sqlCtx, c, schema, rowIter, callback, resultFields, more, buf)
}
Expand Down Expand Up @@ -768,6 +770,139 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s
return r, processedAtLeastOneBatch, nil
}

func (h *Handler) resultForDefaultIter2(ctx *sql.Context, c *mysql.Conn, iter sql.RowIter2, resultFields []*querypb.Field, callback func(*sqltypes.Result, bool) error, more bool) (*sqltypes.Result, bool, error) {
defer trace.StartRegion(ctx, "Handler.resultForDefaultIter2").End()

eg, ctx := ctx.NewErrgroup()
pan2err := func(err *error) {
if recoveredPanic := recover(); recoveredPanic != nil {
stack := debug.Stack()
wrappedErr := fmt.Errorf("handler caught panic: %v\n%s", recoveredPanic, stack)
*err = goerrors.Join(*err, wrappedErr)
}
}

// TODO: poll for closed connections should obviously also run even if
// we're doing something with an OK result or a single row result, etc.
// This should be in the caller.
pollCtx, cancelF := ctx.NewSubContext()
eg.Go(func() (err error) {
defer pan2err(&err)
return h.pollForClosedConnection(pollCtx, c)
})

// Default waitTime is one minute if there is no timeout configured, in which case
// it will loop to iterate again unless the socket died by the OS timeout or other problems.
// If there is a timeout, it will be enforced to ensure that Vitess has a chance to
// call Handler.CloseConnection()
waitTime := 1 * time.Minute
if h.readTimeout > 0 {
waitTime = h.readTimeout
}
timer := time.NewTimer(waitTime)
defer timer.Stop()

wg := sync.WaitGroup{}
wg.Add(2)

// TODO: send results instead of rows?
// Read rows from iter and send them off
var rowChan = make(chan sql.Row2, 512)
eg.Go(func() (err error) {
defer pan2err(&err)
defer wg.Done()
defer close(rowChan)
for {
select {
case <-ctx.Done():
return context.Cause(ctx)
default:
row, err := iter.Next2(ctx)
if err == io.EOF {
return nil
}
if err != nil {
return err
}
select {
case rowChan <- row:
case <-ctx.Done():
return nil
}
}
}
})

var res *sqltypes.Result
var processedAtLeastOneBatch bool
eg.Go(func() (err error) {
defer pan2err(&err)
defer cancelF()
defer wg.Done()
for {
if res == nil {
res = &sqltypes.Result{
Fields: resultFields,
Rows: make([][]sqltypes.Value, 0, rowsBatch),
}
}
if res.RowsAffected == rowsBatch {
if err := callback(res, more); err != nil {
return err
}
res = nil
processedAtLeastOneBatch = true
continue
}

select {
case <-ctx.Done():
return context.Cause(ctx)
case <-timer.C:
if h.readTimeout != 0 {
// Cancel and return so Vitess can call the CloseConnection callback
ctx.GetLogger().Tracef("connection timeout")
return ErrRowTimeout.New()
}
case row, ok := <-rowChan:
if !ok {
return nil
}
resRow := make([]sqltypes.Value, len(row))
for i, v := range row {
resRow[i] = sqltypes.MakeTrusted(v.Typ, v.Val)
}
ctx.GetLogger().Tracef("spooling result row %s", resRow)
res.Rows = append(res.Rows, resRow)
res.RowsAffected++
if !timer.Stop() {
<-timer.C
}
}
timer.Reset(waitTime)
}
})

// Close() kills this PID in the process list,
// wait until all rows have be sent over the wire
eg.Go(func() (err error) {
defer pan2err(&err)
wg.Wait()
return iter.Close(ctx)
})

err := eg.Wait()
if err != nil {
ctx.GetLogger().WithError(err).Warn("error running query")
if verboseErrorLogging {
fmt.Printf("Err: %+v", err)
}
return nil, false, err
}

return res, processedAtLeastOneBatch, nil
}

// See https://dev.mysql.com/doc/internals/en/status-flags.html
func setConnStatusFlags(ctx *sql.Context, c *mysql.Conn) error {
ok, err := isSessionAutocommit(ctx)
Expand Down
12 changes: 2 additions & 10 deletions sql/convert_value.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ package sql
import (
"fmt"

"github.com/dolthub/vitess/go/vt/proto/query"

"github.com/dolthub/go-mysql-server/sql/values"

"github.com/dolthub/vitess/go/vt/proto/query"
)

// ConvertToValue converts the interface to a sql value.
Expand Down Expand Up @@ -90,11 +90,3 @@ func ConvertToValue(v interface{}) (Value, error) {
return Value{}, fmt.Errorf("type %T not implemented", v)
}
}

func MustConvertToValue(v interface{}) Value {
ret, err := ConvertToValue(v)
if err != nil {
panic(err)
}
return ret
}
1 change: 1 addition & 0 deletions sql/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,7 @@ type Expression2 interface {
Eval2(ctx *Context, row Row2) (Value, error)
// Type2 returns the expression type.
Type2() Type2
IsExpr2() bool
}

var SystemVariables SystemVariableRegistry
Expand Down
64 changes: 64 additions & 0 deletions sql/expression/comparison.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package expression
import (
"fmt"

querypb "github.com/dolthub/vitess/go/vt/proto/query"
errors "gopkg.in/src-d/go-errors.v1"

"github.com/dolthub/go-mysql-server/sql"
Expand Down Expand Up @@ -492,6 +493,7 @@ type GreaterThan struct {
}

var _ sql.Expression = (*GreaterThan)(nil)
var _ sql.Expression2 = (*GreaterThan)(nil)
var _ sql.CollationCoercible = (*GreaterThan)(nil)

// NewGreaterThan creates a new GreaterThan expression.
Expand All @@ -518,6 +520,68 @@ func (gt *GreaterThan) Eval(ctx *sql.Context, row sql.Row) (interface{}, error)
return result == 1, nil
}

func (gt *GreaterThan) Eval2(ctx *sql.Context, row sql.Row2) (sql.Value, error) {
l, ok := gt.Left().(sql.Expression2)
if !ok {
panic(fmt.Sprintf("%T does not implement sql.Expression2", gt.Left()))
}
r, ok := gt.Right().(sql.Expression2)
if !ok {
panic(fmt.Sprintf("%T does not implement sql.Expression2", gt.Right()))
}

lv, err := l.Eval2(ctx, row)
if err != nil {
return sql.Value{}, err
}
rv, err := r.Eval2(ctx, row)
if err != nil {
return sql.Value{}, err
}

// TODO: just assume they are int64
l64, err := types.ConvertValueToInt64(types.NumberTypeImpl_{}, lv)
if err != nil {
return sql.Value{}, err
}
r64, err := types.ConvertValueToInt64(types.NumberTypeImpl_{}, rv)
if err != nil {
return sql.Value{}, err
}
var rb byte
if l64 > r64 {
rb = 1
}

ret := sql.Value{
Val: []byte{rb},
Typ: querypb.Type_INT8,
}
return ret, nil
}

func (gt *GreaterThan) Type2() sql.Type2 {
return nil
}

func (gt *GreaterThan) IsExpr2() bool {
lExpr, isExpr2 := gt.Left().(sql.Expression2)
if !isExpr2 {
return false
}
if !lExpr.IsExpr2() {
return false
}
rExpr, isExpr2 := gt.Right().(sql.Expression2)
if !isExpr2 {
return false
}
if !rExpr.IsExpr2() {
return false
}
return true
}

// WithChildren implements the Expression interface.
func (gt *GreaterThan) WithChildren(children ...sql.Expression) (sql.Expression, error) {
if len(children) != 2 {
Expand Down
5 changes: 4 additions & 1 deletion sql/expression/get_field.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,11 @@ func (p *GetField) Eval2(ctx *sql.Context, row sql.Row2) (sql.Value, error) {
if p.fieldIndex < 0 || p.fieldIndex >= row.Len() {
return sql.Value{}, ErrIndexOutOfBounds.New(p.fieldIndex, row.Len())
}
return row[p.fieldIndex], nil
}

return row.GetField(p.fieldIndex), nil
func (p *GetField) IsExpr2() bool {
return true
}

// WithChildren implements the Expression interface.
Expand Down
8 changes: 6 additions & 2 deletions sql/expression/literal.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,10 @@ func (lit *Literal) Eval2(ctx *sql.Context, row sql.Row2) (sql.Value, error) {
return lit.val2, nil
}

func (lit *Literal) IsExpr2() bool {
return true
}

func (lit *Literal) Type2() sql.Type2 {
t2, ok := lit.Typ.(sql.Type2)
if !ok {
Expand All @@ -149,8 +153,8 @@ func (lit *Literal) Type2() sql.Type2 {
}

// Value returns the literal value.
func (p *Literal) Value() interface{} {
return p.Val
func (lit *Literal) Value() interface{} {
return lit.Val
}

func (lit *Literal) WithResolvedChildren(children []any) (any, error) {
Expand Down
4 changes: 4 additions & 0 deletions sql/expression/unresolved.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ func (uc *UnresolvedColumn) Type2() sql.Type2 {
panic("unresolved column is a placeholder node, but Type2 was called")
}

func (uc *UnresolvedColumn) IsExpr2() bool {
panic("unresolved column is a placeholder node, but IsExpr2 was called")
}

// Name implements the Nameable interface.
func (uc *UnresolvedColumn) Name() string { return uc.name }

Expand Down
36 changes: 36 additions & 0 deletions sql/plan/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,14 @@ func (f *Filter) Expressions() []sql.Expression {
type FilterIter struct {
cond sql.Expression
childIter sql.RowIter

cond2 sql.Expression2
childIter2 sql.RowIter2
}

var _ sql.RowIter = (*FilterIter)(nil)
var _ sql.RowIter2 = (*FilterIter)(nil)

// NewFilterIter creates a new FilterIter.
func NewFilterIter(
cond sql.Expression,
Expand Down Expand Up @@ -133,6 +139,36 @@ func (i *FilterIter) Next(ctx *sql.Context) (sql.Row, error) {
}
}

func (i *FilterIter) Next2(ctx *sql.Context) (sql.Row2, error) {
for {
row, err := i.childIter2.Next2(ctx)
if err != nil {
return nil, err
}
res, err := i.cond2.Eval2(ctx, row)
if err != nil {
return nil, err
}
if res.Val[0] == 1 {
return row, nil
}
}
}

func (i *FilterIter) IsRowIter2(ctx *sql.Context) bool {
cond, ok := i.cond.(sql.Expression2)
if !ok || !cond.IsExpr2() {
return false
}
childIter, ok := i.childIter.(sql.RowIter2)
if !ok || !childIter.IsRowIter2(ctx) {
return false
}
i.cond2 = cond
i.childIter2 = childIter
return true
}

// Close implements the RowIter interface.
func (i *FilterIter) Close(ctx *sql.Context) error {
return i.childIter.Close(ctx)
Expand Down
22 changes: 22 additions & 0 deletions sql/plan/process.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ const (
type TrackedRowIter struct {
node sql.Node
iter sql.RowIter
iter2 sql.RowIter2
onDone NotifyFunc
onNext NotifyFunc
numRows int64
Expand Down Expand Up @@ -317,6 +318,27 @@ func (i *TrackedRowIter) Next(ctx *sql.Context) (sql.Row, error) {
return row, nil
}

func (i *TrackedRowIter) Next2(ctx *sql.Context) (sql.Row2, error) {
row, err := i.iter2.Next2(ctx)
if err != nil {
return nil, err
}
i.numRows++
if i.onNext != nil {
i.onNext()
}
return row, nil
}

func (i *TrackedRowIter) IsRowIter2(ctx *sql.Context) bool {
iter, ok := i.iter.(sql.RowIter2)
if !ok || !iter.IsRowIter2(ctx) {
return false
}
i.iter2 = iter
return true
}

func (i *TrackedRowIter) Close(ctx *sql.Context) error {
err := i.iter.Close(ctx)

Expand Down
Loading
Loading