Skip to content
26 changes: 24 additions & 2 deletions client/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
. "github.com/go-mysql-org/go-mysql/mysql"
"github.com/go-mysql-org/go-mysql/packet"
"github.com/pingcap/errors"
"github.com/pingcap/tidb/pkg/parser/charset"
)

const defaultAuthPluginName = AUTH_NATIVE_PASSWORD
Expand Down Expand Up @@ -268,8 +269,24 @@ func (c *Conn) writeAuthHandshake() error {
data[11] = 0x00

// Charset [1 byte]
// use default collation id 33 here, is utf-8
data[12] = DEFAULT_COLLATION_ID
// use default collation id 33 here, is `utf8mb3_general_ci`
collationName := c.collation
if len(collationName) == 0 {
collationName = DEFAULT_COLLATION_NAME
}
collation, err := charset.GetCollationByName(collationName)
if err != nil {
return fmt.Errorf("invalid collation name %s", collationName)
}

// the MySQL protocol calls for the collation id to be sent as 1, where only the
// lower 8 bits are used in this field. But wireshark shows that the first by of
// the 23 bytes of filler is used to send the upper 8 bits of the collation id.
// see https://github.com/mysql/mysql-server/pull/541
data[12] = byte(collation.ID & 0xff)
if collation.ID > 255 {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be possible always encode the collation.ID to a 2 byte integer, if the right endianness is used. That might simplify the code a bit

data[13] = byte(collation.ID >> 8)
}

// SSL Connection Request Packet
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest
Expand All @@ -292,6 +309,11 @@ func (c *Conn) writeAuthHandshake() error {

// Filler [23 bytes] (all 0x00)
pos := 13
if collation.ID > 255 {
// skip setting the first byte of the filler to 0x00 since it is used to
// send the upper 8 bits of the collation id
pos++
}
for ; pos < 13+23; pos++ {
data[pos] = 0
}
Expand Down
73 changes: 73 additions & 0 deletions client/auth_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package client

import (
"github.com/go-mysql-org/go-mysql/packet"
"github.com/pingcap/tidb/pkg/parser/charset"
"net"
"testing"

"github.com/go-mysql-org/go-mysql/mysql"
Expand Down Expand Up @@ -34,3 +37,73 @@ func TestConnGenAttributes(t *testing.T) {
require.Subset(t, data, fixt)
}
}

func TestConnCollation(t *testing.T) {
collations := []string{"big5_chinese_ci",
"utf8_general_ci",
"utf8mb4_0900_ai_ci",
"utf8mb4_de_pb_0900_ai_ci",
"utf8mb4_ja_0900_as_cs",
"utf8mb4_0900_bin",
"utf8mb4_zh_pinyin_tidb_as_cs"}

// test all supported collations by calling writeAuthHandshake() and reading the bytes
// sent to the server to ensure the collation id is set correctly
for _, c := range collations {
collation, err := charset.GetCollationByName(c)
require.NoError(t, err)
server := sendAuthResponse(t, collation.Name)
// read the all the bytes of the handshake response so that client goroutine can complete without blocking
// on the server read.
handShakeResponse := make([]byte, 128)
_, err = server.Read(handShakeResponse)
require.NoError(t, err)

// validate the collation id is set correctly
// if the collation ID is <= 255 the collation ID is stored in the 12th byte
if collation.ID <= 255 {
require.Equal(t, byte(collation.ID), handShakeResponse[12])
// sanity check: validate the 23 bytes of filler with value 0x00 are set correctly
for i := 13; i < 13+23; i++ {
require.Equal(t, byte(0x00), handShakeResponse[i])
}
} else {
// if the collation ID is > 255 the collation ID is stored in the 12th and 13th bytes
require.Equal(t, byte(collation.ID&0xff), handShakeResponse[12])
require.Equal(t, byte(collation.ID>>8), handShakeResponse[13])

// sanity check: validate the 22 bytes of filler with value 0x00 are set correctly
for i := 14; i < 14+22; i++ {
require.Equal(t, byte(0x00), handShakeResponse[i])
}
}

// and finally the username
password := string(handShakeResponse[36:40])
require.Equal(t, "test", password)

require.NoError(t, server.Close())
}
}

func sendAuthResponse(t *testing.T, collation string) net.Conn {
server, client := net.Pipe()
c := &Conn{
Conn: &packet.Conn{
Conn: client,
},
authPluginName: "mysql_native_password",
user: "test",
db: "test",
password: "test",
proto: "tcp",
collation: collation,
salt: ([]byte)("123456781234567812345678"),
}

go func() {
err := c.writeAuthHandshake()
require.NoError(t, err)
}()
return server
}
21 changes: 20 additions & 1 deletion client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@ func TestClientSuite(t *testing.T) {
func (s *clientTestSuite) SetupSuite() {
var err error
addr := fmt.Sprintf("%s:%s", *test_util.MysqlHost, s.port)
s.c, err = Connect(addr, *testUser, *testPassword, "")
s.c, err = Connect(addr, *testUser, *testPassword, "", func(conn *Conn) {
// test the collation logic, but this is essentially a no-op since
// the collation set is the default value
_ = conn.SetCollation(mysql.DEFAULT_COLLATION_NAME)
})
require.NoError(s.T(), err)

var result *mysql.Result
Expand Down Expand Up @@ -228,6 +232,21 @@ func (s *clientTestSuite) TestConn_SetCharset() {
require.NoError(s.T(), err)
}

func (s *clientTestSuite) TestConn_SetCollationAfterConnect() {
err := s.c.SetCollation("latin1_swedish_ci")
require.Error(s.T(), err)
}

func (s *clientTestSuite) TestConn_SetCollation() {
addr := fmt.Sprintf("%s:%s", *test_util.MysqlHost, s.port)
_, err := Connect(addr, *testUser, *testPassword, "", func(conn *Conn) {
// test the collation logic
_ = conn.SetCollation("invalid_collation")
})

require.Error(s.T(), err)
}

func (s *clientTestSuite) testStmt_DropTable() {
str := `drop table if exists mixer_test_stmt`

Expand Down
24 changes: 22 additions & 2 deletions client/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ type Conn struct {
status uint16

charset string
// sets the collation to be set on the auth handshake, this does not issue a 'set names' command
collation string

salt []byte
authPluginName string
Expand Down Expand Up @@ -67,15 +69,19 @@ func Connect(addr string, user string, password string, dbName string, options .
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()

dialer := &net.Dialer{}
return ConnectWithContext(ctx, addr, user, password, dbName, options...)
}

// ConnectWithContext to a MySQL addr using the provided context.
func ConnectWithContext(ctx context.Context, addr string, user string, password string, dbName string, options ...func(*Conn)) (*Conn, error) {
dialer := &net.Dialer{}
return ConnectWithDialer(ctx, "", addr, user, password, dbName, dialer.DialContext, options...)
}

// Dialer connects to the address on the named network using the provided context.
type Dialer func(ctx context.Context, network, address string) (net.Conn, error)

// Connect to a MySQL server using the given Dialer.
// ConnectWithDialer to a MySQL server using the given Dialer.
func ConnectWithDialer(ctx context.Context, network string, addr string, user string, password string, dbName string, dialer Dialer, options ...func(*Conn)) (*Conn, error) {
c := new(Conn)

Expand Down Expand Up @@ -357,6 +363,20 @@ func (c *Conn) SetCharset(charset string) error {
}
}

func (c *Conn) SetCollation(collation string) error {
if c.status == 0 {
c.collation = collation
} else {
return errors.Trace(errors.Errorf("cannot set collation after connection is established"))
}

return nil
}

func (c *Conn) GetCollation() string {
return c.collation
}

func (c *Conn) FieldList(table string, wildcard string) ([]*Field, error) {
if err := c.writeCommandStrStr(COM_FIELD_LIST, table, wildcard); err != nil {
return nil, errors.Trace(err)
Expand Down