Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 97 additions & 0 deletions schema/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package schema

import (
"database/sql"
"fmt"
"strings"

Expand Down Expand Up @@ -161,6 +162,25 @@ func IsTableExist(conn mysql.Executer, schema string, name string) (bool, error)
return r.RowNumber() == 1, nil
}

func NewTableFromSqlDB(conn *sql.DB, schema string, name string) (*Table, error) {
ta := &Table{
Schema: schema,
Name: name,
Columns: make([]TableColumn, 0, 16),
Indexes: make([]*Index, 0, 8),
}

if err := ta.fetchColumnsViaSqlDB(conn); err != nil {
return nil, errors.Trace(err)
}

if err := ta.fetchIndexesViaSqlDB(conn); err != nil {
return nil, errors.Trace(err)
}

return ta, nil
}

func NewTable(conn mysql.Executer, schema string, name string) (*Table, error) {
ta := &Table{
Schema: schema,
Expand Down Expand Up @@ -197,6 +217,30 @@ func (ta *Table) fetchColumns(conn mysql.Executer) error {
return nil
}

func (ta *Table) fetchColumnsViaSqlDB(conn *sql.DB) error {
r, err := conn.Query(fmt.Sprintf("describe `%s`.`%s`", ta.Schema, ta.Name))
if err != nil {
return errors.Trace(err)
}

defer r.Close()

var unusedVal interface{}
unused := &unusedVal

for r.Next() {
var name, colType, extra string
err := r.Scan(&name, &colType, &unused, &unused, &unused, &extra)
if err != nil {
return errors.Trace(err)
}

ta.AddColumn(name, colType, extra)
}

return r.Err()
}

func (ta *Table) fetchIndexes(conn mysql.Executer) error {
r, err := conn.Execute(fmt.Sprintf("show index from `%s`.`%s`", ta.Schema, ta.Name))
if err != nil {
Expand All @@ -216,6 +260,59 @@ func (ta *Table) fetchIndexes(conn mysql.Executer) error {
currentIndex.AddColumn(colName, cardinality)
}

return ta.fetchPrimaryKeyColumns()

}

func (ta *Table) fetchIndexesViaSqlDB(conn *sql.DB) error {
r, err := conn.Query(fmt.Sprintf("show index from `%s`.`%s`", ta.Schema, ta.Name))
if err != nil {
return errors.Trace(err)
}

defer r.Close()

var currentIndex *Index
currentName := ""

var unusedVal interface{}
unused := &unusedVal

for r.Next() {
var indexName, colName string
var cardinality uint64

err := r.Scan(
&unused,
&unused,
&indexName,
&unused,
&colName,
&unused,
&cardinality,
&unused,
&unused,
&unused,
&unused,
&unused,
&unused,
)
if err != nil {
return errors.Trace(err)
}

if currentName != indexName {
currentIndex = ta.AddIndex(indexName)
currentName = indexName
}

currentIndex.AddColumn(colName, cardinality)
}

return ta.fetchPrimaryKeyColumns()
}

func (ta *Table) fetchPrimaryKeyColumns() error {
if len(ta.Indexes) == 0 {
return nil
}
Expand Down
17 changes: 16 additions & 1 deletion schema/schema_test.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
package schema

import (
"database/sql"
"flag"
"fmt"
"testing"

. "github.com/pingcap/check"
"github.com/siddontang/go-mysql/client"
_ "github.com/siddontang/go-mysql/driver"
)

// use docker mysql for test
Expand All @@ -17,7 +19,8 @@ func Test(t *testing.T) {
}

type schemaTestSuite struct {
conn *client.Conn
conn *client.Conn
sqlDB *sql.DB
}

var _ = Suite(&schemaTestSuite{})
Expand All @@ -26,12 +29,19 @@ func (s *schemaTestSuite) SetUpSuite(c *C) {
var err error
s.conn, err = client.Connect(fmt.Sprintf("%s:%d", *host, 3306), "root", "", "test")
c.Assert(err, IsNil)

s.sqlDB, err = sql.Open("mysql", fmt.Sprintf("root:@%s:3306", *host))
c.Assert(err, IsNil)
}

func (s *schemaTestSuite) TearDownSuite(c *C) {
if s.conn != nil {
s.conn.Close()
}

if s.sqlDB != nil {
s.sqlDB.Close()
}
}

func (s *schemaTestSuite) TestSchema(c *C) {
Expand Down Expand Up @@ -74,6 +84,11 @@ func (s *schemaTestSuite) TestSchema(c *C) {
c.Assert(ta.Columns[0].IsUnsigned, IsFalse)
c.Assert(ta.Columns[8].IsUnsigned, IsTrue)
c.Assert(ta.Columns[9].IsUnsigned, IsTrue)

taSqlDb, err := NewTableFromSqlDB(s.sqlDB, "test", "schema_test")
c.Assert(err, IsNil)

c.Assert(taSqlDb, DeepEquals, ta)
}

func (s *schemaTestSuite) TestQuoteSchema(c *C) {
Expand Down