Skip to content

Commit 3409818

Browse files
committed
poc: context support
1 parent f7c4036 commit 3409818

File tree

3 files changed

+106
-1
lines changed

3 files changed

+106
-1
lines changed

client/conn.go

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ type Conn struct {
3535
// Connection read and write timeouts to set on the connection
3636
ReadTimeout time.Duration
3737
WriteTimeout time.Duration
38+
contexts chan context.Context
39+
closed bool
3840

3941
// The buffer size to use in the packet connection
4042
BufferSize int
@@ -136,6 +138,7 @@ func ConnectWithDialer(ctx context.Context, network, addr, user, password, dbNam
136138
c.password = password
137139
c.db = dbName
138140
c.proto = network
141+
c.contexts = make(chan context.Context)
139142

140143
// use default charset here, utf-8
141144
c.charset = mysql.DEFAULT_CHARSET
@@ -184,6 +187,21 @@ func ConnectWithDialer(ctx context.Context, network, addr, user, password, dbNam
184187
}
185188
}
186189

190+
go func() {
191+
ctx := context.Background()
192+
for {
193+
var ok bool
194+
select {
195+
case <-ctx.Done():
196+
_ = c.Conn.SetDeadline(time.Unix(0, 0))
197+
case ctx, ok = <-c.contexts:
198+
if !ok {
199+
return
200+
}
201+
}
202+
}
203+
}()
204+
187205
return c, nil
188206
}
189207

@@ -208,8 +226,19 @@ func (c *Conn) handshake() error {
208226
return nil
209227
}
210228

229+
func (c *Conn) watchCtx(ctx context.Context) func() {
230+
c.contexts <- ctx
231+
return func() {
232+
c.contexts <- context.Background()
233+
}
234+
}
235+
211236
// Close directly closes the connection. Use Quit() to first send COM_QUIT to the server and then close the connection.
212237
func (c *Conn) Close() error {
238+
if !c.closed {
239+
close(c.contexts)
240+
c.closed = true
241+
}
213242
return c.Conn.Close()
214243
}
215244

@@ -309,6 +338,11 @@ func (c *Conn) Execute(command string, args ...interface{}) (*mysql.Result, erro
309338
}
310339
}
311340

341+
func (c *Conn) ExecuteContext(ctx context.Context, command string, args ...interface{}) (*mysql.Result, error) {
342+
defer c.watchCtx(ctx)
343+
return c.Execute(command, args...)
344+
}
345+
312346
// ExecuteMultiple will call perResultCallback for every result of the multiple queries
313347
// that are executed.
314348
//
@@ -363,6 +397,11 @@ func (c *Conn) ExecuteMultiple(query string, perResultCallback ExecPerResultCall
363397
return mysql.NewResult(rs), nil
364398
}
365399

400+
func (c *Conn) ExecuteMultipleContext(ctx context.Context, query string, perResultCallback ExecPerResultCallback) (*mysql.Result, error) {
401+
defer c.watchCtx(ctx)
402+
return c.ExecuteMultiple(query, perResultCallback)
403+
}
404+
366405
// ExecuteSelectStreaming will call perRowCallback for every row in resultset
367406
// WITHOUT saving any row data to Result.{Values/RawPkg/RowDatas} fields.
368407
// When given, perResultCallback will be called once per result
@@ -376,11 +415,33 @@ func (c *Conn) ExecuteSelectStreaming(command string, result *mysql.Result, perR
376415
return c.readResultStreaming(false, result, perRowCallback, perResultCallback)
377416
}
378417

418+
func (c *Conn) ExecuteSelectStreamingContext(ctx context.Context, command string, result *mysql.Result, perRowCallback SelectPerRowCallback, perResultCallback SelectPerResultCallback) error {
419+
defer c.watchCtx(ctx)
420+
return c.ExecuteSelectStreaming(command, result, perRowCallback, perResultCallback)
421+
}
422+
379423
func (c *Conn) Begin() error {
380424
_, err := c.exec("BEGIN")
381425
return errors.Trace(err)
382426
}
383427

428+
func (c *Conn) BeginTx(ctx context.Context, readOnly bool, txIsolation string) error {
429+
defer c.watchCtx(ctx)()
430+
431+
if txIsolation != "" {
432+
if _, err := c.exec("SET TRANSACTION ISOLATION LEVEL " + txIsolation); err != nil {
433+
return errors.Trace(err)
434+
}
435+
}
436+
var err error
437+
if readOnly {
438+
_, err = c.exec("START TRANSACTION READ ONLY")
439+
} else {
440+
_, err = c.exec("START TRANSACTION")
441+
}
442+
return errors.Trace(err)
443+
}
444+
384445
func (c *Conn) Commit() error {
385446
_, err := c.exec("COMMIT")
386447
return errors.Trace(err)

driver/driver.go

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
package driver
44

55
import (
6+
"context"
67
"crypto/tls"
78
"database/sql"
89
sqldriver "database/sql/driver"
@@ -185,6 +186,9 @@ type CheckNamedValueFunc func(*sqldriver.NamedValue) error
185186
var (
186187
_ sqldriver.NamedValueChecker = &conn{}
187188
_ sqldriver.Validator = &conn{}
189+
_ sqldriver.Conn = &conn{}
190+
_ sqldriver.ConnBeginTx = &conn{}
191+
_ sqldriver.QueryerContext = &conn{}
188192
)
189193

190194
type state struct {
@@ -242,6 +246,27 @@ func (c *conn) Begin() (sqldriver.Tx, error) {
242246
return &tx{c.Conn}, nil
243247
}
244248

249+
var isolationLevelTransactionIsolation = map[sql.IsolationLevel]string{
250+
sql.LevelDefault: "",
251+
sql.LevelRepeatableRead: "REPEATABLE READ",
252+
sql.LevelReadCommitted: "READ COMMITTED",
253+
sql.LevelReadUncommitted: "READ UNCOMMITTED",
254+
sql.LevelSerializable: "SERIALIZABLE",
255+
}
256+
257+
func (c *conn) BeginTx(ctx context.Context, opts sqldriver.TxOptions) (sqldriver.Tx, error) {
258+
isolation := sql.IsolationLevel(opts.Isolation)
259+
txIsolation, ok := isolationLevelTransactionIsolation[isolation]
260+
if !ok {
261+
return nil, fmt.Errorf("invalid mysql transaction isolation level %s", isolation)
262+
}
263+
err := c.Conn.BeginTx(ctx, opts.ReadOnly, txIsolation)
264+
if err != nil {
265+
return nil, errors.Trace(err)
266+
}
267+
return &tx{c.Conn}, nil
268+
}
269+
245270
func buildArgs(args []sqldriver.Value) []interface{} {
246271
a := make([]interface{}, len(args))
247272

@@ -252,6 +277,17 @@ func buildArgs(args []sqldriver.Value) []interface{} {
252277
return a
253278
}
254279

280+
func buildNamedArgs(args []sqldriver.NamedValue) []interface{} {
281+
a := make([]interface{}, len(args))
282+
283+
for i, arg := range args {
284+
// TODO named parameter support
285+
a[i] = arg.Value
286+
}
287+
288+
return a
289+
}
290+
255291
func (st *state) replyError(err error) error {
256292
isBadConnection := mysql.ErrorEqual(err, mysql.ErrBadConn)
257293

@@ -284,6 +320,15 @@ func (c *conn) Query(query string, args []sqldriver.Value) (sqldriver.Rows, erro
284320
return newRows(r.Resultset)
285321
}
286322

323+
func (c *conn) QueryContext(ctx context.Context, query string, args []sqldriver.NamedValue) (sqldriver.Rows, error) {
324+
a := buildNamedArgs(args)
325+
r, err := c.Conn.ExecuteContext(ctx, query, a...)
326+
if err != nil {
327+
return nil, c.state.replyError(err)
328+
}
329+
return newRows(r.Resultset)
330+
}
331+
287332
type stmt struct {
288333
*client.Stmt
289334
connectionState *state

packet/conn.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ type Conn struct {
3131

3232
readTimeout time.Duration
3333
writeTimeout time.Duration
34-
ctx context.Context
3534

3635
// Buffered reader for net.Conn in Non-TLS connection only to address replication performance issue.
3736
// See https://github.com/go-mysql-org/go-mysql/pull/422 for more details.

0 commit comments

Comments
 (0)