@@ -35,6 +35,8 @@ type Conn struct {
3535// Connection read and write timeouts to set on the connection
3636ReadTimeout time.Duration
3737WriteTimeout time.Duration
38+ contexts chan context.Context
39+ closed bool
3840
3941// The buffer size to use in the packet connection
4042BufferSize int
@@ -136,6 +138,7 @@ func ConnectWithDialer(ctx context.Context, network, addr, user, password, dbNam
136138c .password = password
137139c .db = dbName
138140c .proto = network
141+ c .contexts = make (chan context.Context )
139142
140143// use default charset here, utf-8
141144c .charset = mysql .DEFAULT_CHARSET
@@ -184,6 +187,21 @@ func ConnectWithDialer(ctx context.Context, network, addr, user, password, dbNam
184187}
185188}
186189
190+ go func () {
191+ ctx := context .Background ()
192+ for {
193+ var ok bool
194+ select {
195+ case <- ctx .Done ():
196+ _ = c .Conn .SetDeadline (time .Unix (0 , 0 ))
197+ case ctx , ok = <- c .contexts :
198+ if ! ok {
199+ return
200+ }
201+ }
202+ }
203+ }()
204+
187205return c , nil
188206}
189207
@@ -208,8 +226,19 @@ func (c *Conn) handshake() error {
208226return nil
209227}
210228
229+ func (c * Conn ) watchCtx (ctx context.Context ) func () {
230+ c .contexts <- ctx
231+ return func () {
232+ c .contexts <- context .Background ()
233+ }
234+ }
235+
211236// Close directly closes the connection. Use Quit() to first send COM_QUIT to the server and then close the connection.
212237func (c * Conn ) Close () error {
238+ if ! c .closed {
239+ close (c .contexts )
240+ c .closed = true
241+ }
213242return c .Conn .Close ()
214243}
215244
@@ -309,6 +338,11 @@ func (c *Conn) Execute(command string, args ...interface{}) (*mysql.Result, erro
309338}
310339}
311340
341+ func (c * Conn ) ExecuteContext (ctx context.Context , command string , args ... interface {}) (* mysql.Result , error ) {
342+ defer c .watchCtx (ctx )
343+ return c .Execute (command , args ... )
344+ }
345+
312346// ExecuteMultiple will call perResultCallback for every result of the multiple queries
313347// that are executed.
314348//
@@ -363,6 +397,11 @@ func (c *Conn) ExecuteMultiple(query string, perResultCallback ExecPerResultCall
363397return mysql .NewResult (rs ), nil
364398}
365399
400+ func (c * Conn ) ExecuteMultipleContext (ctx context.Context , query string , perResultCallback ExecPerResultCallback ) (* mysql.Result , error ) {
401+ defer c .watchCtx (ctx )
402+ return c .ExecuteMultiple (query , perResultCallback )
403+ }
404+
366405// ExecuteSelectStreaming will call perRowCallback for every row in resultset
367406// WITHOUT saving any row data to Result.{Values/RawPkg/RowDatas} fields.
368407// When given, perResultCallback will be called once per result
@@ -376,11 +415,33 @@ func (c *Conn) ExecuteSelectStreaming(command string, result *mysql.Result, perR
376415return c .readResultStreaming (false , result , perRowCallback , perResultCallback )
377416}
378417
418+ func (c * Conn ) ExecuteSelectStreamingContext (ctx context.Context , command string , result * mysql.Result , perRowCallback SelectPerRowCallback , perResultCallback SelectPerResultCallback ) error {
419+ defer c .watchCtx (ctx )
420+ return c .ExecuteSelectStreaming (command , result , perRowCallback , perResultCallback )
421+ }
422+
379423func (c * Conn ) Begin () error {
380424_ , err := c .exec ("BEGIN" )
381425return errors .Trace (err )
382426}
383427
428+ func (c * Conn ) BeginTx (ctx context.Context , readOnly bool , txIsolation string ) error {
429+ defer c .watchCtx (ctx )()
430+
431+ if txIsolation != "" {
432+ if _ , err := c .exec ("SET TRANSACTION ISOLATION LEVEL " + txIsolation ); err != nil {
433+ return errors .Trace (err )
434+ }
435+ }
436+ var err error
437+ if readOnly {
438+ _ , err = c .exec ("START TRANSACTION READ ONLY" )
439+ } else {
440+ _ , err = c .exec ("START TRANSACTION" )
441+ }
442+ return errors .Trace (err )
443+ }
444+
384445func (c * Conn ) Commit () error {
385446_ , err := c .exec ("COMMIT" )
386447return errors .Trace (err )
0 commit comments