33package  driver
44
55import  (
6+ "context" 
67"crypto/tls" 
78"database/sql" 
89sqldriver "database/sql/driver" 
@@ -21,6 +22,23 @@ import (
2122"github.com/pingcap/errors" 
2223)
2324
25+ var  (
26+ _  sqldriver.Driver  =  & driver {}
27+ _  sqldriver.DriverContext  =  & driver {}
28+ _  sqldriver.Connector  =  & connInfo {}
29+ _  sqldriver.NamedValueChecker  =  & conn {}
30+ _  sqldriver.Validator  =  & conn {}
31+ _  sqldriver.Conn  =  & conn {}
32+ _  sqldriver.Pinger  =  & conn {}
33+ _  sqldriver.ConnBeginTx  =  & conn {}
34+ _  sqldriver.ConnPrepareContext  =  & conn {}
35+ _  sqldriver.ExecerContext  =  & conn {}
36+ _  sqldriver.QueryerContext  =  & conn {}
37+ _  sqldriver.Stmt  =  & stmt {}
38+ _  sqldriver.StmtExecContext  =  & stmt {}
39+ _  sqldriver.StmtQueryContext  =  & stmt {}
40+ )
41+ 
2442var  customTLSMutex  sync.Mutex 
2543
2644// Map of dsn address (makes more sense than full dsn?) to tls Config 
@@ -101,16 +119,18 @@ func parseDSN(dsn string) (connInfo, error) {
101119// Open takes a supplied DSN string and opens a connection 
102120// See ParseDSN for more information on the form of the DSN 
103121func  (d  driver ) Open (dsn  string ) (sqldriver.Conn , error ) {
104- var  (
105- c  * client.Conn 
106- // by default database/sql driver retries will be enabled 
107- retries  =  true 
108- )
109- 
110122ci , err  :=  parseDSN (dsn )
111123if  err  !=  nil  {
112124return  nil , err 
113125}
126+ return  ci .Connect (context .Background ())
127+ }
128+ 
129+ func  (ci  connInfo ) Connect (ctx  context.Context ) (sqldriver.Conn , error ) {
130+ var  c  * client.Conn 
131+ var  err  error 
132+ // by default database/sql driver retries will be enabled 
133+ retries  :=  true 
114134
115135if  ci .standardDSN  {
116136var  timeout  time.Duration 
@@ -159,45 +179,86 @@ func (d driver) Open(dsn string) (sqldriver.Conn, error) {
159179}
160180}
161181
162- if  timeout  >  0  {
163- c , err  =  client .ConnectWithTimeout (ci .addr , ci .user , ci .password , ci .db , timeout , configuredOptions ... )
164- } else  {
165- c , err  =  client .Connect (ci .addr , ci .user , ci .password , ci .db , configuredOptions ... )
182+ if  timeout  <=  0  {
183+ timeout  =  10  *  time .Second 
166184}
185+ c , err  =  client .ConnectWithContext (ctx , ci .addr , ci .user , ci .password , ci .db , timeout , configuredOptions ... )
167186} else  {
168187// No more processing here. Let's only support url parameters with the newer style DSN 
169- c , err  =  client .Connect ( ci .addr , ci .user , ci .password , ci .db )
188+ c , err  =  client .ConnectWithContext ( ctx ,  ci .addr , ci .user , ci .password , ci .db ,  10 * time . Second )
170189}
171190if  err  !=  nil  {
172191return  nil , err 
173192}
174193
194+ contexts  :=  make (chan  context.Context )
195+ go  func () {
196+ ctx  :=  context .Background ()
197+ for  {
198+ var  ok  bool 
199+ select  {
200+ case  <- ctx .Done ():
201+ ctx  =  context .Background ()
202+ _  =  c .Conn .Close ()
203+ case  ctx , ok  =  <- contexts :
204+ if  ! ok  {
205+ return 
206+ }
207+ }
208+ }
209+ }()
210+ 
175211// if retries are 'on' then return sqldriver.ErrBadConn which will trigger up to 3 
176212// retries by the database/sql package. If retries are 'off' then we'll return 
177213// the native go-mysql-org/go-mysql 'mysql.ErrBadConn' erorr which will prevent a retry. 
178214// In this case the sqldriver.Validator interface is implemented and will return 
179215// false for IsValid() signaling the connection is bad and should be discarded. 
180- return  & conn {Conn : c , state : & state {valid : true , useStdLibErrors : retries }}, nil 
216+ return  & conn {
217+ Conn : c ,
218+ state : & state {contexts : contexts , valid : true , useStdLibErrors : retries },
219+ }, nil 
181220}
182221
183- type  CheckNamedValueFunc  func (* sqldriver.NamedValue ) error 
222+ func  (d  driver ) OpenConnector (name  string ) (sqldriver.Connector , error ) {
223+ return  parseDSN (name )
224+ }
184225
185- var  (
186- _  sqldriver.NamedValueChecker  =  & conn {}
187- _  sqldriver.Validator  =  & conn {}
188- )
226+ func  (ci  connInfo ) Driver () sqldriver.Driver  {
227+ return  driver {}
228+ }
229+ 
230+ type  CheckNamedValueFunc  func (* sqldriver.NamedValue ) error 
189231
190232type  state  struct  {
191- valid  bool 
233+ contexts  chan  context.Context 
234+ valid  bool 
192235// when true, the driver connection will return ErrBadConn from the golang Standard Library 
193236useStdLibErrors  bool 
194237}
195238
239+ func  (s  * state ) watchCtx (ctx  context.Context ) func () {
240+ s .contexts  <-  ctx 
241+ return  func () {
242+ s .contexts  <-  context .Background ()
243+ }
244+ }
245+ 
246+ func  (s  * state ) Close () {
247+ if  s .contexts  !=  nil  {
248+ close (s .contexts )
249+ s .contexts  =  nil 
250+ }
251+ }
252+ 
196253type  conn  struct  {
197254* client.Conn 
198255state  * state 
199256}
200257
258+ func  (c  * conn ) watchCtx (ctx  context.Context ) func () {
259+ return  c .state .watchCtx (ctx )
260+ }
261+ 
201262func  (c  * conn ) CheckNamedValue (nv  * sqldriver.NamedValue ) error  {
202263for  _ , nvChecker  :=  range  namedValueCheckers  {
203264err  :=  nvChecker (nv )
@@ -220,6 +281,17 @@ func (c *conn) IsValid() bool {
220281return  c .state .valid 
221282}
222283
284+ func  (c  * conn ) Ping (ctx  context.Context ) error  {
285+ defer  c .watchCtx (ctx )()
286+ if  err  :=  c .Conn .Ping (); err  !=  nil  {
287+ if  err  ==  context .DeadlineExceeded  ||  err  ==  context .Canceled  {
288+ return  err 
289+ }
290+ return  sqldriver .ErrBadConn 
291+ }
292+ return  nil 
293+ }
294+ 
223295func  (c  * conn ) Prepare (query  string ) (sqldriver.Stmt , error ) {
224296st , err  :=  c .Conn .Prepare (query )
225297if  err  !=  nil  {
@@ -229,7 +301,13 @@ func (c *conn) Prepare(query string) (sqldriver.Stmt, error) {
229301return  & stmt {Stmt : st , connectionState : c .state }, nil 
230302}
231303
304+ func  (c  * conn ) PrepareContext (ctx  context.Context , query  string ) (sqldriver.Stmt , error ) {
305+ defer  c .watchCtx (ctx )()
306+ return  c .Prepare (query )
307+ }
308+ 
232309func  (c  * conn ) Close () error  {
310+ c .state .Close ()
233311return  c .Conn .Close ()
234312}
235313
@@ -242,6 +320,29 @@ func (c *conn) Begin() (sqldriver.Tx, error) {
242320return  & tx {c .Conn }, nil 
243321}
244322
323+ var  isolationLevelTransactionIsolation  =  map [sql.IsolationLevel ]string {
324+ sql .LevelDefault : "" ,
325+ sql .LevelRepeatableRead : "REPEATABLE READ" ,
326+ sql .LevelReadCommitted : "READ COMMITTED" ,
327+ sql .LevelReadUncommitted : "READ UNCOMMITTED" ,
328+ sql .LevelSerializable : "SERIALIZABLE" ,
329+ }
330+ 
331+ func  (c  * conn ) BeginTx (ctx  context.Context , opts  sqldriver.TxOptions ) (sqldriver.Tx , error ) {
332+ defer  c .watchCtx (ctx )()
333+ 
334+ isolation  :=  sql .IsolationLevel (opts .Isolation )
335+ txIsolation , ok  :=  isolationLevelTransactionIsolation [isolation ]
336+ if  ! ok  {
337+ return  nil , fmt .Errorf ("invalid mysql transaction isolation level %s" , isolation )
338+ }
339+ err  :=  c .Conn .BeginTx (opts .ReadOnly , txIsolation )
340+ if  err  !=  nil  {
341+ return  nil , errors .Trace (err )
342+ }
343+ return  & tx {c .Conn }, nil 
344+ }
345+ 
245346func  buildArgs (args  []sqldriver.Value ) []interface {} {
246347a  :=  make ([]interface {}, len (args ))
247348
@@ -252,6 +353,16 @@ func buildArgs(args []sqldriver.Value) []interface{} {
252353return  a 
253354}
254355
356+ func  buildNamedArgs (args  []sqldriver.NamedValue ) []interface {} {
357+ a  :=  make ([]interface {}, len (args ))
358+ 
359+ for  i , arg  :=  range  args  {
360+ a [i ] =  arg .Value 
361+ }
362+ 
363+ return  a 
364+ }
365+ 
255366func  (st  * state ) replyError (err  error ) error  {
256367isBadConnection  :=  mysql .ErrorEqual (err , mysql .ErrBadConn )
257368
@@ -275,6 +386,16 @@ func (c *conn) Exec(query string, args []sqldriver.Value) (sqldriver.Result, err
275386return  & result {r }, nil 
276387}
277388
389+ func  (c  * conn ) ExecContext (ctx  context.Context , query  string , args  []sqldriver.NamedValue ) (sqldriver.Result , error ) {
390+ defer  c .watchCtx (ctx )()
391+ a  :=  buildNamedArgs (args )
392+ r , err  :=  c .Conn .Execute (query , a ... )
393+ if  err  !=  nil  {
394+ return  nil , c .state .replyError (err )
395+ }
396+ return  & result {r }, nil 
397+ }
398+ 
278399func  (c  * conn ) Query (query  string , args  []sqldriver.Value ) (sqldriver.Rows , error ) {
279400a  :=  buildArgs (args )
280401r , err  :=  c .Conn .Execute (query , a ... )
@@ -284,11 +405,25 @@ func (c *conn) Query(query string, args []sqldriver.Value) (sqldriver.Rows, erro
284405return  newRows (r .Resultset )
285406}
286407
408+ func  (c  * conn ) QueryContext (ctx  context.Context , query  string , args  []sqldriver.NamedValue ) (sqldriver.Rows , error ) {
409+ defer  c .watchCtx (ctx )()
410+ a  :=  buildNamedArgs (args )
411+ r , err  :=  c .Conn .Execute (query , a ... )
412+ if  err  !=  nil  {
413+ return  nil , c .state .replyError (err )
414+ }
415+ return  newRows (r .Resultset )
416+ }
417+ 
287418type  stmt  struct  {
288419* client.Stmt 
289420connectionState  * state 
290421}
291422
423+ func  (s  * stmt ) watchCtx (ctx  context.Context ) func () {
424+ return  s .connectionState .watchCtx (ctx )
425+ }
426+ 
292427func  (s  * stmt ) Close () error  {
293428return  s .Stmt .Close ()
294429}
@@ -306,6 +441,17 @@ func (s *stmt) Exec(args []sqldriver.Value) (sqldriver.Result, error) {
306441return  & result {r }, nil 
307442}
308443
444+ func  (s  * stmt ) ExecContext (ctx  context.Context , args  []sqldriver.NamedValue ) (sqldriver.Result , error ) {
445+ defer  s .watchCtx (ctx )()
446+ 
447+ a  :=  buildNamedArgs (args )
448+ r , err  :=  s .Stmt .Execute (a ... )
449+ if  err  !=  nil  {
450+ return  nil , s .connectionState .replyError (err )
451+ }
452+ return  & result {r }, nil 
453+ }
454+ 
309455func  (s  * stmt ) Query (args  []sqldriver.Value ) (sqldriver.Rows , error ) {
310456a  :=  buildArgs (args )
311457r , err  :=  s .Stmt .Execute (a ... )
@@ -315,6 +461,17 @@ func (s *stmt) Query(args []sqldriver.Value) (sqldriver.Rows, error) {
315461return  newRows (r .Resultset )
316462}
317463
464+ func  (s  * stmt ) QueryContext (ctx  context.Context , args  []sqldriver.NamedValue ) (sqldriver.Rows , error ) {
465+ defer  s .watchCtx (ctx )()
466+ 
467+ a  :=  buildNamedArgs (args )
468+ r , err  :=  s .Stmt .Execute (a ... )
469+ if  err  !=  nil  {
470+ return  nil , s .connectionState .replyError (err )
471+ }
472+ return  newRows (r .Resultset )
473+ }
474+ 
318475type  tx  struct  {
319476* client.Conn 
320477}
0 commit comments