- Notifications
You must be signed in to change notification settings - Fork 1k
client,mysql: Add support for Query Attributes #976
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 19 commits
740b4c6 1390b58 c833ce4 f562050 4cbdd13 deb4492 2a29ba0 6f41c36 5ad9e06 3143ff8 a4f390b 2d288d4 5b4d37f f6fda23 9e75208 7cad496 d330e39 25f3bae 9865fd5 e1b8e65 16b678e e939d95 622b0ba 7b93b16 File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| | @@ -57,6 +57,11 @@ type Conn struct { | |
| authPluginName string | ||
| | ||
| connectionID uint32 | ||
| | ||
| queryAttributes []QueryAttribute | ||
| | ||
| // Include the file + line as query attribute. The number set which frame in the stack should be used. | ||
| includeLine int | ||
| } | ||
| | ||
| // This function will be called for every row in resultset from ExecuteSelectStreaming. | ||
| | @@ -100,6 +105,7 @@ type Dialer func(ctx context.Context, network, address string) (net.Conn, error) | |
| func ConnectWithDialer(ctx context.Context, network, addr, user, password, dbName string, dialer Dialer, options ...Option) (*Conn, error) { | ||
| c := new(Conn) | ||
| | ||
| c.includeLine = -1 | ||
| c.BufferSize = defaultBufferSize | ||
| c.attributes = map[string]string{ | ||
| "_client_name": "go-mysql", | ||
| | @@ -310,7 +316,7 @@ func (c *Conn) Execute(command string, args ...interface{}) (*Result, error) { | |
| // flag set to signal the server multiple queries are executed. Handling the responses | ||
| // is up to the implementation of perResultCallback. | ||
| func (c *Conn) ExecuteMultiple(query string, perResultCallback ExecPerResultCallback) (*Result, error) { | ||
| if err := c.writeCommandStr(COM_QUERY, query); err != nil { | ||
| if err := c.execSend(query); err != nil { | ||
| return nil, errors.Trace(err) | ||
| } | ||
| | ||
| | @@ -363,7 +369,7 @@ func (c *Conn) ExecuteMultiple(query string, perResultCallback ExecPerResultCall | |
| // | ||
| // ExecuteSelectStreaming should be used only for SELECT queries with a large response resultset for memory preserving. | ||
| func (c *Conn) ExecuteSelectStreaming(command string, result *Result, perRowCallback SelectPerRowCallback, perResultCallback SelectPerResultCallback) error { | ||
| if err := c.writeCommandStr(COM_QUERY, command); err != nil { | ||
| if err := c.execSend(command); err != nil { | ||
| return errors.Trace(err) | ||
| } | ||
| | ||
| | @@ -489,14 +495,64 @@ func (c *Conn) ReadOKPacket() (*Result, error) { | |
| return c.readOK() | ||
| } | ||
| | ||
| // Send COM_QUERY and read the result | ||
| func (c *Conn) exec(query string) (*Result, error) { | ||
| if err := c.writeCommandStr(COM_QUERY, query); err != nil { | ||
| err := c.execSend(query) | ||
| if err != nil { | ||
| return nil, errors.Trace(err) | ||
| } | ||
| | ||
| return c.readResult(false) | ||
| } | ||
| | ||
| // Sends COM_QUERY | ||
| // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query.html | ||
| func (c *Conn) execSend(query string) error { | ||
| var buf bytes.Buffer | ||
| defer clear(c.queryAttributes) | ||
| | ||
| if c.capability&CLIENT_QUERY_ATTRIBUTES > 0 { | ||
| if c.includeLine >= 0 { | ||
| _, file, line, ok := runtime.Caller(c.includeLine) | ||
| if ok { | ||
| lineAttr := QueryAttribute{ | ||
| Name: "_line", | ||
| Value: fmt.Sprintf("%s:%d", file, line), | ||
| } | ||
| c.queryAttributes = append(c.queryAttributes, lineAttr) | ||
| } | ||
| } | ||
| | ||
| numParams := len(c.queryAttributes) | ||
| buf.Write(PutLengthEncodedInt(uint64(numParams))) | ||
| buf.WriteByte(0x1) // parameter_set_count, unused | ||
| if numParams > 0 { | ||
| // null_bitmap, length: (num_params+7)/8 | ||
| for i := 0; i < (numParams+7)/8; i++ { | ||
| buf.WriteByte(0x0) | ||
| There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we allow NULL query attributes so the bitmap will have | ||
| } | ||
| buf.WriteByte(0x1) // new_params_bind_flag, unused | ||
| for _, qa := range c.queryAttributes { | ||
| buf.Write(qa.TypeAndFlag()) | ||
| buf.Write(PutLengthEncodedString([]byte(qa.Name))) | ||
| } | ||
| for _, qa := range c.queryAttributes { | ||
| buf.Write(qa.ValueBytes()) | ||
| } | ||
| } | ||
| } | ||
| | ||
| _, err := buf.Write(utils.StringToByteSlice(query)) | ||
| if err != nil { | ||
| return err | ||
| } | ||
| | ||
| if err := c.writeCommandBuf(COM_QUERY, buf.Bytes()); err != nil { | ||
| return errors.Trace(err) | ||
| } | ||
| | ||
| return nil | ||
| } | ||
| | ||
| // CapabilityString is returning a string with the names of capability flags | ||
| // separated by "|". Examples of capability names are CLIENT_DEPRECATE_EOF and CLIENT_PROTOCOL_41. | ||
| // These are defined as constants in the mysql package. | ||
| | @@ -627,3 +683,17 @@ func (c *Conn) StatusString() string { | |
| | ||
| return strings.Join(stats, "|") | ||
| } | ||
| | ||
| // SetQueryAttributes sets the query attributes to be send along with the next query | ||
| func (c *Conn) SetQueryAttributes(attrs ...QueryAttribute) error { | ||
| c.queryAttributes = attrs | ||
| return nil | ||
| } | ||
| | ||
| // IncludeLine can be passed as option when connecting to include the file name and line number | ||
| // of the caller as query attribute `_line` when sending queries. | ||
| // The argument is used the dept in the stack. The top level is go-mysql and then there are the | ||
| // levels of the application. | ||
| func (c *Conn) IncludeLine(frame int) { | ||
| c.includeLine = frame | ||
| } | ||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| | @@ -5,6 +5,7 @@ import ( | |||||
| "encoding/json" | ||||||
| "fmt" | ||||||
| "math" | ||||||
| "runtime" | ||||||
| | ||||||
| . "github.com/go-mysql-org/go-mysql/mysql" | ||||||
| "github.com/go-mysql-org/go-mysql/utils" | ||||||
| | @@ -56,18 +57,34 @@ func (s *Stmt) Close() error { | |||||
| return nil | ||||||
| } | ||||||
| | ||||||
| // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_stmt_execute.html | ||||||
| func (s *Stmt) write(args ...interface{}) error { | ||||||
| defer clear(s.conn.queryAttributes) | ||||||
| paramsNum := s.params | ||||||
| | ||||||
| if len(args) != paramsNum { | ||||||
| return fmt.Errorf("argument mismatch, need %d but got %d", s.params, len(args)) | ||||||
| } | ||||||
| | ||||||
| paramTypes := make([]byte, paramsNum<<1) | ||||||
| paramValues := make([][]byte, paramsNum) | ||||||
| if (s.conn.capability&CLIENT_QUERY_ATTRIBUTES > 0) && (s.conn.includeLine >= 0) { | ||||||
| There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Suggested change
the old code has clear priority. It's OK to keep it. | ||||||
| _, file, line, ok := runtime.Caller(s.conn.includeLine) | ||||||
| if ok { | ||||||
| lineAttr := QueryAttribute{ | ||||||
| Name: "_line", | ||||||
| Value: fmt.Sprintf("%s:%d", file, line), | ||||||
| } | ||||||
| s.conn.queryAttributes = append(s.conn.queryAttributes, lineAttr) | ||||||
| } | ||||||
| } | ||||||
| | ||||||
| qaLen := len(s.conn.queryAttributes) | ||||||
| paramTypes := make([][]byte, paramsNum+qaLen) | ||||||
| paramFlags := make([][]byte, paramsNum+qaLen) | ||||||
| paramValues := make([][]byte, paramsNum+qaLen) | ||||||
| paramNames := make([][]byte, paramsNum+qaLen) | ||||||
| | ||||||
| //NULL-bitmap, length: (num-params+7) | ||||||
| nullBitmap := make([]byte, (paramsNum+7)>>3) | ||||||
| nullBitmap := make([]byte, (paramsNum+qaLen+7)>>3) | ||||||
| | ||||||
| length := 1 + 4 + 1 + 4 + ((paramsNum + 7) >> 3) + 1 + (paramsNum << 1) | ||||||
| | ||||||
| | @@ -76,76 +93,89 @@ func (s *Stmt) write(args ...interface{}) error { | |||||
| for i := range args { | ||||||
| if args[i] == nil { | ||||||
| nullBitmap[i/8] |= 1 << (uint(i) % 8) | ||||||
| There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this part is a bit complicated, I'll review later 😂 | ||||||
| paramTypes[i<<1] = MYSQL_TYPE_NULL | ||||||
| paramTypes[i] = []byte{MYSQL_TYPE_NULL} | ||||||
| paramNames[i] = []byte{0} // length encoded, no name | ||||||
| paramFlags[i] = []byte{0} | ||||||
| continue | ||||||
| } | ||||||
| | ||||||
| newParamBoundFlag = 1 | ||||||
| | ||||||
| switch v := args[i].(type) { | ||||||
| case int8: | ||||||
| paramTypes[i<<1] = MYSQL_TYPE_TINY | ||||||
| paramTypes[i] = []byte{MYSQL_TYPE_TINY} | ||||||
| paramValues[i] = []byte{byte(v)} | ||||||
| case int16: | ||||||
| paramTypes[i<<1] = MYSQL_TYPE_SHORT | ||||||
| paramTypes[i] = []byte{MYSQL_TYPE_SHORT} | ||||||
| paramValues[i] = Uint16ToBytes(uint16(v)) | ||||||
| case int32: | ||||||
| paramTypes[i<<1] = MYSQL_TYPE_LONG | ||||||
| paramTypes[i] = []byte{MYSQL_TYPE_LONG} | ||||||
| paramValues[i] = Uint32ToBytes(uint32(v)) | ||||||
| case int: | ||||||
| paramTypes[i<<1] = MYSQL_TYPE_LONGLONG | ||||||
| paramTypes[i] = []byte{MYSQL_TYPE_LONGLONG} | ||||||
| paramValues[i] = Uint64ToBytes(uint64(v)) | ||||||
| case int64: | ||||||
| paramTypes[i<<1] = MYSQL_TYPE_LONGLONG | ||||||
| paramTypes[i] = []byte{MYSQL_TYPE_LONGLONG} | ||||||
| paramValues[i] = Uint64ToBytes(uint64(v)) | ||||||
| case uint8: | ||||||
| paramTypes[i<<1] = MYSQL_TYPE_TINY | ||||||
| paramTypes[(i<<1)+1] = 0x80 | ||||||
| paramTypes[i] = []byte{MYSQL_TYPE_TINY} | ||||||
| paramFlags[i] = []byte{PARAM_UNSIGNED} | ||||||
| paramValues[i] = []byte{v} | ||||||
| case uint16: | ||||||
| paramTypes[i<<1] = MYSQL_TYPE_SHORT | ||||||
| paramTypes[(i<<1)+1] = 0x80 | ||||||
| paramTypes[i] = []byte{MYSQL_TYPE_SHORT} | ||||||
| paramFlags[i] = []byte{PARAM_UNSIGNED} | ||||||
| paramValues[i] = Uint16ToBytes(v) | ||||||
| case uint32: | ||||||
| paramTypes[i<<1] = MYSQL_TYPE_LONG | ||||||
| paramTypes[(i<<1)+1] = 0x80 | ||||||
| paramTypes[i] = []byte{MYSQL_TYPE_LONG} | ||||||
| paramFlags[i] = []byte{PARAM_UNSIGNED} | ||||||
| paramValues[i] = Uint32ToBytes(v) | ||||||
| case uint: | ||||||
| paramTypes[i<<1] = MYSQL_TYPE_LONGLONG | ||||||
| paramTypes[(i<<1)+1] = 0x80 | ||||||
| paramTypes[i] = []byte{MYSQL_TYPE_LONGLONG} | ||||||
| paramFlags[i] = []byte{PARAM_UNSIGNED} | ||||||
| paramValues[i] = Uint64ToBytes(uint64(v)) | ||||||
| case uint64: | ||||||
| paramTypes[i<<1] = MYSQL_TYPE_LONGLONG | ||||||
| paramTypes[(i<<1)+1] = 0x80 | ||||||
| paramTypes[i] = []byte{MYSQL_TYPE_LONGLONG} | ||||||
| paramFlags[i] = []byte{PARAM_UNSIGNED} | ||||||
| paramValues[i] = Uint64ToBytes(v) | ||||||
| case bool: | ||||||
| paramTypes[i<<1] = MYSQL_TYPE_TINY | ||||||
| paramTypes[i] = []byte{MYSQL_TYPE_TINY} | ||||||
| if v { | ||||||
| paramValues[i] = []byte{1} | ||||||
| } else { | ||||||
| paramValues[i] = []byte{0} | ||||||
| } | ||||||
| case float32: | ||||||
| paramTypes[i<<1] = MYSQL_TYPE_FLOAT | ||||||
| paramTypes[i] = []byte{MYSQL_TYPE_FLOAT} | ||||||
| paramValues[i] = Uint32ToBytes(math.Float32bits(v)) | ||||||
| case float64: | ||||||
| paramTypes[i<<1] = MYSQL_TYPE_DOUBLE | ||||||
| paramTypes[i] = []byte{MYSQL_TYPE_DOUBLE} | ||||||
| paramValues[i] = Uint64ToBytes(math.Float64bits(v)) | ||||||
| case string: | ||||||
| paramTypes[i<<1] = MYSQL_TYPE_STRING | ||||||
| paramTypes[i] = []byte{MYSQL_TYPE_STRING} | ||||||
| paramValues[i] = append(PutLengthEncodedInt(uint64(len(v))), v...) | ||||||
| case []byte: | ||||||
| paramTypes[i<<1] = MYSQL_TYPE_STRING | ||||||
| paramTypes[i] = []byte{MYSQL_TYPE_STRING} | ||||||
| paramValues[i] = append(PutLengthEncodedInt(uint64(len(v))), v...) | ||||||
| case json.RawMessage: | ||||||
| paramTypes[i<<1] = MYSQL_TYPE_STRING | ||||||
| paramTypes[i] = []byte{MYSQL_TYPE_STRING} | ||||||
| paramValues[i] = append(PutLengthEncodedInt(uint64(len(v))), v...) | ||||||
| default: | ||||||
| return fmt.Errorf("invalid argument type %T", args[i]) | ||||||
| } | ||||||
| paramNames[i] = []byte{0} // length encoded, no name | ||||||
| if paramFlags[i] == nil { | ||||||
| paramFlags[i] = []byte{0} | ||||||
| } | ||||||
| | ||||||
| length += len(paramValues[i]) | ||||||
| } | ||||||
| for i, qa := range s.conn.queryAttributes { | ||||||
| tf := qa.TypeAndFlag() | ||||||
| paramTypes[(i + paramsNum)] = []byte{tf[0]} | ||||||
| paramFlags[i+paramsNum] = []byte{tf[1]} | ||||||
| paramValues[i+paramsNum] = qa.ValueBytes() | ||||||
| paramNames[i+paramsNum] = PutLengthEncodedString([]byte(qa.Name)) | ||||||
| } | ||||||
| | ||||||
| data := utils.BytesBufferGet() | ||||||
| defer func() { | ||||||
| | @@ -159,25 +189,40 @@ func (s *Stmt) write(args ...interface{}) error { | |||||
| data.WriteByte(COM_STMT_EXECUTE) | ||||||
| data.Write([]byte{byte(s.id), byte(s.id >> 8), byte(s.id >> 16), byte(s.id >> 24)}) | ||||||
| | ||||||
| //flag: CURSOR_TYPE_NO_CURSOR | ||||||
| data.WriteByte(0x00) | ||||||
| flags := CURSOR_TYPE_NO_CURSOR | ||||||
| if paramsNum > 0 { | ||||||
| flags |= PARAMETER_COUNT_AVAILABLE | ||||||
| } | ||||||
| data.WriteByte(flags) | ||||||
| | ||||||
| //iteration-count, always 1 | ||||||
| data.Write([]byte{1, 0, 0, 0}) | ||||||
| | ||||||
| if s.params > 0 { | ||||||
| data.Write(nullBitmap) | ||||||
| | ||||||
| //new-params-bound-flag | ||||||
| data.WriteByte(newParamBoundFlag) | ||||||
| | ||||||
| if newParamBoundFlag == 1 { | ||||||
| //type of each parameter, length: num-params * 2 | ||||||
| data.Write(paramTypes) | ||||||
| | ||||||
| //value of each parameter | ||||||
| for _, v := range paramValues { | ||||||
| data.Write(v) | ||||||
| if paramsNum > 0 || (s.conn.capability&CLIENT_QUERY_ATTRIBUTES > 0 && (flags&PARAMETER_COUNT_AVAILABLE > 0)) { | ||||||
| if s.conn.capability&CLIENT_QUERY_ATTRIBUTES > 0 { | ||||||
| paramsNum += len(s.conn.queryAttributes) | ||||||
| data.Write(PutLengthEncodedInt(uint64(paramsNum))) | ||||||
| } | ||||||
| if paramsNum > 0 { | ||||||
| data.Write(nullBitmap) | ||||||
| | ||||||
| //new-params-bound-flag | ||||||
| data.WriteByte(newParamBoundFlag) | ||||||
| | ||||||
| if newParamBoundFlag == 1 { | ||||||
| for i := 0; i < paramsNum; i++ { | ||||||
| data.Write(paramTypes[i]) | ||||||
| data.Write(paramFlags[i]) | ||||||
| | ||||||
| if s.conn.capability&CLIENT_QUERY_ATTRIBUTES > 0 { | ||||||
| data.Write(paramNames[i]) | ||||||
| } | ||||||
| } | ||||||
| | ||||||
| //value of each parameter | ||||||
| for _, v := range paramValues { | ||||||
| data.Write(v) | ||||||
| } | ||||||
| } | ||||||
| } | ||||||
| } | ||||||
| | ||||||
Uh oh!
There was an error while loading. Please reload this page.