@@ -35,6 +35,8 @@ type Conn struct {
3535// Connection read and write timeouts to set on the connection 
3636ReadTimeout  time.Duration 
3737WriteTimeout  time.Duration 
38+ contexts  chan  context.Context 
39+ closed  bool 
3840
3941// The buffer size to use in the packet connection 
4042BufferSize  int 
@@ -136,6 +138,7 @@ func ConnectWithDialer(ctx context.Context, network, addr, user, password, dbNam
136138c .password  =  password 
137139c .db  =  dbName 
138140c .proto  =  network 
141+ c .contexts  =  make (chan  context.Context )
139142
140143// use default charset here, utf-8 
141144c .charset  =  mysql .DEFAULT_CHARSET 
@@ -184,6 +187,21 @@ func ConnectWithDialer(ctx context.Context, network, addr, user, password, dbNam
184187}
185188}
186189
190+ go  func () {
191+ ctx  :=  context .Background ()
192+ for  {
193+ var  ok  bool 
194+ select  {
195+ case  <- ctx .Done ():
196+ _  =  c .Conn .SetDeadline (time .Unix (0 , 0 ))
197+ case  ctx , ok  =  <- c .contexts :
198+ if  ! ok  {
199+ return 
200+ }
201+ }
202+ }
203+ }()
204+ 
187205return  c , nil 
188206}
189207
@@ -208,8 +226,19 @@ func (c *Conn) handshake() error {
208226return  nil 
209227}
210228
229+ func  (c  * Conn ) watchCtx (ctx  context.Context ) func () {
230+ c .contexts  <-  ctx 
231+ return  func () {
232+ c .contexts  <-  context .Background ()
233+ }
234+ }
235+ 
211236// Close directly closes the connection. Use Quit() to first send COM_QUIT to the server and then close the connection. 
212237func  (c  * Conn ) Close () error  {
238+ if  ! c .closed  {
239+ close (c .contexts )
240+ c .closed  =  true 
241+ }
213242return  c .Conn .Close ()
214243}
215244
@@ -309,6 +338,11 @@ func (c *Conn) Execute(command string, args ...interface{}) (*mysql.Result, erro
309338}
310339}
311340
341+ func  (c  * Conn ) ExecuteContext (ctx  context.Context , command  string , args  ... interface {}) (* mysql.Result , error ) {
342+ defer  c .watchCtx (ctx )
343+ return  c .Execute (command , args )
344+ }
345+ 
312346// ExecuteMultiple will call perResultCallback for every result of the multiple queries 
313347// that are executed. 
314348// 
@@ -363,6 +397,11 @@ func (c *Conn) ExecuteMultiple(query string, perResultCallback ExecPerResultCall
363397return  mysql .NewResult (rs ), nil 
364398}
365399
400+ func  (c  * Conn ) ExecuteMultipleContext (ctx  context.Context , query  string , perResultCallback  ExecPerResultCallback ) (* mysql.Result , error ) {
401+ defer  c .watchCtx (ctx )
402+ return  c .ExecuteMultiple (query , perResultCallback )
403+ }
404+ 
366405// ExecuteSelectStreaming will call perRowCallback for every row in resultset 
367406// WITHOUT saving any row data to Result.{Values/RawPkg/RowDatas} fields. 
368407// When given, perResultCallback will be called once per result 
@@ -376,11 +415,33 @@ func (c *Conn) ExecuteSelectStreaming(command string, result *mysql.Result, perR
376415return  c .readResultStreaming (false , result , perRowCallback , perResultCallback )
377416}
378417
418+ func  (c  * Conn ) ExecuteSelectStreamingContext (ctx  context.Context , command  string , result  * mysql.Result , perRowCallback  SelectPerRowCallback , perResultCallback  SelectPerResultCallback ) error  {
419+ defer  c .watchCtx (ctx )
420+ return  c .ExecuteSelectStreaming (command , result , perRowCallback , perResultCallback )
421+ }
422+ 
379423func  (c  * Conn ) Begin () error  {
380424_ , err  :=  c .exec ("BEGIN" )
381425return  errors .Trace (err )
382426}
383427
428+ func  (c  * Conn ) BeginTx (ctx  context.Context , readOnly  bool , txIsolation  string ) error  {
429+ defer  c .watchCtx (ctx )()
430+ 
431+ if  txIsolation  !=  ""  {
432+ if  _ , err  :=  c .exec ("SET TRANSACTION ISOLATION LEVEL "  +  txIsolation ); err  !=  nil  {
433+ return  errors .Trace (err )
434+ }
435+ }
436+ var  err  error 
437+ if  readOnly  {
438+ _ , err  =  c .exec ("START TRANSACTION READ ONLY" )
439+ } else  {
440+ _ , err  =  c .exec ("START TRANSACTION" )
441+ }
442+ return  errors .Trace (err )
443+ }
444+ 
384445func  (c  * Conn ) Commit () error  {
385446_ , err  :=  c .exec ("COMMIT" )
386447return  errors .Trace (err )
0 commit comments