@@ -22,6 +22,23 @@ import (
2222"github.com/pingcap/errors"
2323)
2424
25+ var (
26+ _ sqldriver.Driver = & driver {}
27+ _ sqldriver.DriverContext = & driver {}
28+ _ sqldriver.Connector = & connInfo {}
29+ _ sqldriver.NamedValueChecker = & conn {}
30+ _ sqldriver.Validator = & conn {}
31+ _ sqldriver.Conn = & conn {}
32+ _ sqldriver.Pinger = & conn {}
33+ _ sqldriver.ConnBeginTx = & conn {}
34+ _ sqldriver.ConnPrepareContext = & conn {}
35+ _ sqldriver.ExecerContext = & conn {}
36+ _ sqldriver.QueryerContext = & conn {}
37+ _ sqldriver.Stmt = & stmt {}
38+ _ sqldriver.StmtExecContext = & stmt {}
39+ _ sqldriver.StmtQueryContext = & stmt {}
40+ )
41+
2542var customTLSMutex sync.Mutex
2643
2744// Map of dsn address (makes more sense than full dsn?) to tls Config
@@ -102,16 +119,19 @@ func parseDSN(dsn string) (connInfo, error) {
102119// Open takes a supplied DSN string and opens a connection
103120// See ParseDSN for more information on the form of the DSN
104121func (d driver ) Open (dsn string ) (sqldriver.Conn , error ) {
105- var (
106- c * client.Conn
107- // by default database/sql driver retries will be enabled
108- retries = true
109- )
110-
111122ci , err := parseDSN (dsn )
112123if err != nil {
113124return nil , err
114125}
126+ return ci .Connect (context .Background ())
127+
128+ }
129+
130+ func (ci connInfo ) Connect (ctx context.Context ) (sqldriver.Conn , error ) {
131+ var c * client.Conn
132+ var err error
133+ // by default database/sql driver retries will be enabled
134+ retries := true
115135
116136if ci .standardDSN {
117137var timeout time.Duration
@@ -160,48 +180,85 @@ func (d driver) Open(dsn string) (sqldriver.Conn, error) {
160180}
161181}
162182
163- if timeout > 0 {
164- c , err = client .ConnectWithTimeout (ci .addr , ci .user , ci .password , ci .db , timeout , configuredOptions ... )
165- } else {
166- c , err = client .Connect (ci .addr , ci .user , ci .password , ci .db , configuredOptions ... )
183+ if timeout <= 0 {
184+ timeout = 10 * time .Second
167185}
186+ c , err = client .ConnectWithContext (ctx , ci .addr , ci .user , ci .password , ci .db , timeout , configuredOptions ... )
168187} else {
169188// No more processing here. Let's only support url parameters with the newer style DSN
170- c , err = client .Connect ( ci .addr , ci .user , ci .password , ci .db )
189+ c , err = client .ConnectWithContext ( ctx , ci .addr , ci .user , ci .password , ci .db , 10 * time . Second )
171190}
172191if err != nil {
173192return nil , err
174193}
175194
195+ contexts := make (chan context.Context )
196+ go func () {
197+ ctx := context .Background ()
198+ for {
199+ var ok bool
200+ select {
201+ case <- ctx .Done ():
202+ _ = c .Conn .Close ()
203+ case ctx , ok = <- contexts :
204+ if ! ok {
205+ return
206+ }
207+ }
208+ }
209+ }()
210+
176211// if retries are 'on' then return sqldriver.ErrBadConn which will trigger up to 3
177212// retries by the database/sql package. If retries are 'off' then we'll return
178213// the native go-mysql-org/go-mysql 'mysql.ErrBadConn' erorr which will prevent a retry.
179214// In this case the sqldriver.Validator interface is implemented and will return
180215// false for IsValid() signaling the connection is bad and should be discarded.
181- return & conn {Conn : c , state : & state {valid : true , useStdLibErrors : retries }}, nil
216+ return & conn {
217+ Conn : c ,
218+ state : & state {contexts : contexts , valid : true , useStdLibErrors : retries },
219+ }, nil
182220}
183221
184- type CheckNamedValueFunc func (* sqldriver.NamedValue ) error
222+ func (d driver ) OpenConnector (name string ) (sqldriver.Connector , error ) {
223+ return parseDSN (name )
224+ }
185225
186- var (
187- _ sqldriver.NamedValueChecker = & conn {}
188- _ sqldriver.Validator = & conn {}
189- _ sqldriver.Conn = & conn {}
190- _ sqldriver.ConnBeginTx = & conn {}
191- _ sqldriver.QueryerContext = & conn {}
192- )
226+ func (ci connInfo ) Driver () sqldriver.Driver {
227+ return driver {}
228+ }
229+
230+ type CheckNamedValueFunc func (* sqldriver.NamedValue ) error
193231
194232type state struct {
195- valid bool
233+ contexts chan context.Context
234+ valid bool
196235// when true, the driver connection will return ErrBadConn from the golang Standard Library
197236useStdLibErrors bool
198237}
199238
239+ func (s * state ) watchCtx (ctx context.Context ) func () {
240+ s .contexts <- ctx
241+ return func () {
242+ s .contexts <- context .Background ()
243+ }
244+ }
245+
246+ func (s * state ) Close () {
247+ if s .contexts != nil {
248+ close (s .contexts )
249+ s .contexts = nil
250+ }
251+ }
252+
200253type conn struct {
201254* client.Conn
202255state * state
203256}
204257
258+ func (c * conn ) watchCtx (ctx context.Context ) func () {
259+ return c .state .watchCtx (ctx )
260+ }
261+
205262func (c * conn ) CheckNamedValue (nv * sqldriver.NamedValue ) error {
206263for _ , nvChecker := range namedValueCheckers {
207264err := nvChecker (nv )
@@ -224,6 +281,17 @@ func (c *conn) IsValid() bool {
224281return c .state .valid
225282}
226283
284+ func (c * conn ) Ping (ctx context.Context ) error {
285+ defer c .watchCtx (ctx )()
286+ if err := c .Conn .Ping (); err != nil {
287+ if err == context .DeadlineExceeded || err == context .Canceled {
288+ return err
289+ }
290+ return sqldriver .ErrBadConn
291+ }
292+ return nil
293+ }
294+
227295func (c * conn ) Prepare (query string ) (sqldriver.Stmt , error ) {
228296st , err := c .Conn .Prepare (query )
229297if err != nil {
@@ -233,7 +301,13 @@ func (c *conn) Prepare(query string) (sqldriver.Stmt, error) {
233301return & stmt {Stmt : st , connectionState : c .state }, nil
234302}
235303
304+ func (c * conn ) PrepareContext (ctx context.Context , query string ) (sqldriver.Stmt , error ) {
305+ defer c .watchCtx (ctx )()
306+ return c .Prepare (query )
307+ }
308+
236309func (c * conn ) Close () error {
310+ c .state .Close ()
237311return c .Conn .Close ()
238312}
239313
@@ -255,12 +329,14 @@ var isolationLevelTransactionIsolation = map[sql.IsolationLevel]string{
255329}
256330
257331func (c * conn ) BeginTx (ctx context.Context , opts sqldriver.TxOptions ) (sqldriver.Tx , error ) {
332+ defer c .watchCtx (ctx )()
333+
258334isolation := sql .IsolationLevel (opts .Isolation )
259335txIsolation , ok := isolationLevelTransactionIsolation [isolation ]
260336if ! ok {
261337return nil , fmt .Errorf ("invalid mysql transaction isolation level %s" , isolation )
262338}
263- err := c .Conn .BeginTx (ctx , opts .ReadOnly , txIsolation )
339+ err := c .Conn .BeginTx (opts .ReadOnly , txIsolation )
264340if err != nil {
265341return nil , errors .Trace (err )
266342}
@@ -311,6 +387,16 @@ func (c *conn) Exec(query string, args []sqldriver.Value) (sqldriver.Result, err
311387return & result {r }, nil
312388}
313389
390+ func (c * conn ) ExecContext (ctx context.Context , query string , args []sqldriver.NamedValue ) (sqldriver.Result , error ) {
391+ defer c .watchCtx (ctx )()
392+ a := buildNamedArgs (args )
393+ r , err := c .Conn .Execute (query , a ... )
394+ if err != nil {
395+ return nil , c .state .replyError (err )
396+ }
397+ return & result {r }, nil
398+ }
399+
314400func (c * conn ) Query (query string , args []sqldriver.Value ) (sqldriver.Rows , error ) {
315401a := buildArgs (args )
316402r , err := c .Conn .Execute (query , a ... )
@@ -321,8 +407,9 @@ func (c *conn) Query(query string, args []sqldriver.Value) (sqldriver.Rows, erro
321407}
322408
323409func (c * conn ) QueryContext (ctx context.Context , query string , args []sqldriver.NamedValue ) (sqldriver.Rows , error ) {
410+ defer c .watchCtx (ctx )()
324411a := buildNamedArgs (args )
325- r , err := c .Conn .ExecuteContext ( ctx , query , a ... )
412+ r , err := c .Conn .Execute ( query , a ... )
326413if err != nil {
327414return nil , c .state .replyError (err )
328415}
@@ -334,6 +421,10 @@ type stmt struct {
334421connectionState * state
335422}
336423
424+ func (s * stmt ) watchCtx (ctx context.Context ) func () {
425+ return s .connectionState .watchCtx (ctx )
426+ }
427+
337428func (s * stmt ) Close () error {
338429return s .Stmt .Close ()
339430}
@@ -351,6 +442,17 @@ func (s *stmt) Exec(args []sqldriver.Value) (sqldriver.Result, error) {
351442return & result {r }, nil
352443}
353444
445+ func (s * stmt ) ExecContext (ctx context.Context , args []sqldriver.NamedValue ) (sqldriver.Result , error ) {
446+ defer s .watchCtx (ctx )()
447+
448+ a := buildNamedArgs (args )
449+ r , err := s .Stmt .Execute (a ... )
450+ if err != nil {
451+ return nil , s .connectionState .replyError (err )
452+ }
453+ return & result {r }, nil
454+ }
455+
354456func (s * stmt ) Query (args []sqldriver.Value ) (sqldriver.Rows , error ) {
355457a := buildArgs (args )
356458r , err := s .Stmt .Execute (a ... )
@@ -360,6 +462,17 @@ func (s *stmt) Query(args []sqldriver.Value) (sqldriver.Rows, error) {
360462return newRows (r .Resultset )
361463}
362464
465+ func (s * stmt ) QueryContext (ctx context.Context , args []sqldriver.NamedValue ) (sqldriver.Rows , error ) {
466+ defer s .watchCtx (ctx )()
467+
468+ a := buildNamedArgs (args )
469+ r , err := s .Stmt .Execute (a ... )
470+ if err != nil {
471+ return nil , s .connectionState .replyError (err )
472+ }
473+ return newRows (r .Resultset )
474+ }
475+
363476type tx struct {
364477* client.Conn
365478}
0 commit comments