Forráskód Böngészése

Implemented handshake response packet.

Josh Brickner 7 éve
szülő
commit
ffae52d0bf
3 módosított fájl, 245 hozzáadás és 19 törlés
  1. 17 10
      binlog/connection.go
  2. 227 8
      binlog/handshake.go
  3. 1 1
      main.go

+ 17 - 10
binlog/connection.go

@@ -12,6 +12,7 @@ import (
 )
 
 const NullByte byte = '\x00'
+const MaxPacketSize = 16777216
 
 // MySQL Packet Data Types
 const TypeNullTerminatedString = int(0)
@@ -19,11 +20,14 @@ const TypeFixedString = int(1)
 const TypeFixedInt = int(2)
 
 type Config struct {
-	Host    string `json:"host"`
-	Port    int    `json:"port"`
-	User    string `json:"user"`
-	Pass    string `json:"password"`
-	Timeout time.Duration
+	Host       string `json:"host"`
+	Port       int    `json:"port"`
+	User       string `json:"user"`
+	Pass       string `json:"password"`
+	Database   string `json:"database"`
+	SSL        bool   `json:"ssl"`
+	VerifyCert bool   `json:"verify_cert"`
+	Timeout    time.Duration
 }
 
 func newBinlogConfig(dsn string) (*Config, error) {
@@ -36,8 +40,8 @@ func newBinlogConfig(dsn string) (*Config, error) {
 }
 
 type Conn struct {
-	Config  *Config
-	tcpConn *net.TCPConn
+	Config    *Config
+	tcpConn   *net.TCPConn
 	Handshake *HandshakePacket
 }
 
@@ -59,7 +63,7 @@ func (c Conn) Begin() (driver.Tx, error) {
 	return nil, nil
 }
 
-type Driver struct {}
+type Driver struct{}
 
 func (d Driver) Open(dsn string) (driver.Conn, error) {
 	config, err := newBinlogConfig(dsn)
@@ -69,7 +73,7 @@ func (d Driver) Open(dsn string) (driver.Conn, error) {
 
 	blConn := newBinlogConn(config)
 
-	dialer := net.Dialer{Timeout: blConn.Config.Timeout,}
+	dialer := net.Dialer{Timeout: blConn.Config.Timeout}
 	addr := fmt.Sprintf("%s:%d", blConn.Config.Host, blConn.Config.Port)
 	c, err := dialer.Dial("tcp", addr)
 	blConn.tcpConn = c.(*net.TCPConn)
@@ -85,7 +89,10 @@ func (d Driver) Open(dsn string) (driver.Conn, error) {
 	hsp, err := blConn.handshakePacket()
 	blConn.Handshake = hsp
 
-	fmt.Printf("%+v", blConn.Handshake)
+	resp := blConn.handshakeResponse()
+	b := resp.encode()
+	fmt.Printf("%d", b)
+	_, err = blConn.tcpConn.Write(b)
 
 	return blConn, err
 }

+ 227 - 8
binlog/handshake.go

@@ -1,5 +1,10 @@
 package binlog
 
+import (
+	"bytes"
+	"encoding/binary"
+)
+
 type HandshakePacket struct {
 	PacketLength         uint64
 	SequenceId           uint64
@@ -38,6 +43,7 @@ type CapabilityFlags struct {
 	MultiStatements            bool
 	MultiResults               bool
 	PsMultiResults             bool
+	PluginAuth                 bool
 	ConnectAttrs               bool
 	PluginAuthLenEncClientData bool
 	CanHandleExpiredPasswords  bool
@@ -86,14 +92,15 @@ func (hs *HandshakePacket) decodeCapabilityFlags() {
 		MultiStatements:            (hs.CapabilityFlags2[0] & 1) > 0,
 		MultiResults:               (hs.CapabilityFlags2[0] & 2) > 0,
 		PsMultiResults:             (hs.CapabilityFlags2[0] & 4) > 0,
-		ConnectAttrs:               (hs.CapabilityFlags2[0] & 8) > 0,
-		PluginAuthLenEncClientData: (hs.CapabilityFlags2[0] & 16) > 0,
-		CanHandleExpiredPasswords:  (hs.CapabilityFlags2[0] & 32) > 0,
-		SessionTrack:               (hs.CapabilityFlags2[0] & 64) > 0,
-		DeprecateEOF:               (hs.CapabilityFlags2[1] & 128) > 0,
-		SslVerifyServerCert:        (hs.CapabilityFlags2[1] & 1) > 0,
-		OptionalResultSetMetadata:  (hs.CapabilityFlags2[1] & 2) > 0,
-		RememberOptions:            (hs.CapabilityFlags2[1] & 4) > 0,
+		PluginAuth:                 (hs.CapabilityFlags2[0] & 8) > 0,
+		ConnectAttrs:               (hs.CapabilityFlags2[0] & 16) > 0,
+		PluginAuthLenEncClientData: (hs.CapabilityFlags2[0] & 32) > 0,
+		CanHandleExpiredPasswords:  (hs.CapabilityFlags2[0] & 64) > 0,
+		SessionTrack:               (hs.CapabilityFlags2[0] & 128) > 0,
+		DeprecateEOF:               (hs.CapabilityFlags2[1] & 1) > 0,
+		SslVerifyServerCert:        (hs.CapabilityFlags2[1] & 2) > 0,
+		OptionalResultSetMetadata:  (hs.CapabilityFlags2[1] & 4) > 0,
+		RememberOptions:            (hs.CapabilityFlags2[1] & 8) > 0,
 	}
 }
 
@@ -201,3 +208,215 @@ func (c *Conn) handshakePacket() (*HandshakePacket, error) {
 
 	return &packet, nil
 }
+
+type HandshakeResponse struct {
+	ClientFlag         *CapabilityFlags
+	MaxPacketSize      uint64
+	CharacterSet       uint64
+	Username           string
+	AuthResponseLength uint64
+	AuthResponse       string
+	Database           string
+	ClientPluginName   string
+	KeyValues          map[string]string
+}
+
+func (hr *HandshakeResponse) encode() []byte {
+	buf := bytes.NewBuffer(make([]byte, 0))
+
+	// Capabilities flag.
+	flags := make([]byte, 4)
+	if hr.ClientFlag.LongPassword {
+		flags[0] |= 0x1
+	}
+
+	if hr.ClientFlag.FoundRows {
+		flags[0] |= 0x2
+	}
+
+	if hr.ClientFlag.LongFlag {
+		flags[0] |= 0x4
+	}
+
+	if hr.ClientFlag.ConnectWithDb {
+		flags[0] |= 0x8
+	}
+
+	if hr.ClientFlag.NoSchema {
+		flags[0] |= 0x10
+	}
+
+	if hr.ClientFlag.Compress {
+		flags[0] |= 0x20
+	}
+
+	if hr.ClientFlag.Odbc {
+		flags[0] |= 0x40
+	}
+
+	if hr.ClientFlag.LocalFiles {
+		flags[0] |= 0x80
+	}
+
+	if hr.ClientFlag.IgnoreSpace {
+		flags[1] |= 0x1
+	}
+
+	if hr.ClientFlag.Protocol41 {
+		flags[1] |= 0x2
+	}
+
+	if hr.ClientFlag.Interactive {
+		flags[1] |= 0x4
+	}
+
+	if hr.ClientFlag.Ssl {
+		flags[1] |= 0x8
+	}
+
+	if hr.ClientFlag.IgnoreSigpipe {
+		flags[1] |= 0x10
+	}
+
+	if hr.ClientFlag.Transactions {
+		flags[1] |= 0x20
+	}
+
+	if hr.ClientFlag.Reserved {
+		flags[1] |= 0x40
+	}
+
+	if hr.ClientFlag.Reserved2 {
+		flags[1] |= 0x80
+	}
+
+	if hr.ClientFlag.MultiStatements {
+		flags[2] |= 0x1
+	}
+
+	if hr.ClientFlag.MultiResults {
+		flags[2] |= 0x2
+	}
+
+	if hr.ClientFlag.PsMultiResults {
+		flags[2] |= 0x4
+	}
+
+	if hr.ClientFlag.PluginAuth {
+		flags[2] |= 0x8
+	}
+
+	if hr.ClientFlag.ConnectAttrs {
+		flags[2] |= 0x10
+	}
+
+	if hr.ClientFlag.PluginAuthLenEncClientData {
+		flags[2] |= 0x20
+	}
+
+	if hr.ClientFlag.CanHandleExpiredPasswords {
+		flags[2] |= 0x40
+	}
+
+	if hr.ClientFlag.SessionTrack {
+		flags[2] |= 0x80
+	}
+
+	if hr.ClientFlag.DeprecateEOF {
+		flags[3] |= 0x1
+	}
+
+	if hr.ClientFlag.SslVerifyServerCert {
+		flags[3] |= 0x2
+	}
+
+	if hr.ClientFlag.OptionalResultSetMetadata {
+		flags[3] |= 0x4
+	}
+
+	if hr.ClientFlag.RememberOptions {
+		flags[3] |= 0x8
+	}
+
+	// Write Capability Flags.
+	buf.Write(flags)
+
+	// Write MaxPacketSize
+	mps := make([]byte, 4)
+	binary.LittleEndian.PutUint32(mps, uint32(MaxPacketSize))
+	buf.Write(mps)
+
+	// Write CharacterSet
+	cs := make([]byte, 2)
+	binary.LittleEndian.PutUint16(cs, uint16(hr.CharacterSet))
+	buf.Write(cs[:1])
+
+	// Write Filler
+	buf.Write(make([]byte, 23))
+
+	// Write username
+	u := append([]byte(hr.Username), NullByte)
+	buf.Write(u)
+
+	if hr.ClientFlag.PluginAuth && hr.AuthResponseLength > 0 {
+		pal := make([]byte, 2)
+		binary.LittleEndian.PutUint16(pal, uint16(hr.AuthResponseLength))
+		buf.Write(pal[:1])
+		buf.Write([]byte(hr.AuthResponse))
+	}
+
+	if hr.ClientFlag.ConnectWithDb {
+		buf.Write(append([]byte(hr.Database), NullByte))
+	}
+
+	pl := make([]byte, 4)
+	binary.LittleEndian.PutUint32(pl, uint32(buf.Len()))
+	p := append(pl[:3], 1)
+	p = append(p, buf.Bytes()...)
+	buf = bytes.NewBuffer(p)
+
+	return buf.Bytes()
+}
+
+func (c *Conn) handshakeResponse() *HandshakeResponse {
+	return &HandshakeResponse{
+		ClientFlag: &CapabilityFlags{
+			LongPassword:               true,
+			FoundRows:                  true,
+			LongFlag:                   true,
+			ConnectWithDb:              true,
+			NoSchema:                   false,
+			Compress:                   true,
+			Odbc:                       false,
+			LocalFiles:                 false,
+			IgnoreSpace:                true,
+			Protocol41:                 true,
+			Interactive:                true,
+			Ssl:                        c.Config.SSL,
+			IgnoreSigpipe:              false,
+			Transactions:               true,
+			Reserved:                   false,
+			Reserved2:                  false,
+			MultiStatements:            false,
+			MultiResults:               false,
+			PsMultiResults:             true,
+			PluginAuth:                 true,
+			ConnectAttrs:               false,
+			PluginAuthLenEncClientData: true,
+			CanHandleExpiredPasswords:  false,
+			SessionTrack:               true,
+			DeprecateEOF:               true,
+			SslVerifyServerCert:        c.Config.VerifyCert,
+			OptionalResultSetMetadata:  true,
+			RememberOptions:            true,
+		},
+		MaxPacketSize:      MaxPacketSize,
+		CharacterSet:       45,
+		Username:           c.Config.User,
+		AuthResponseLength: 5,
+		AuthResponse:       "Hello",
+		Database:           c.Config.Database,
+		ClientPluginName:   c.Handshake.AuthPluginName,
+		KeyValues:          nil,
+	}
+}

+ 1 - 1
main.go

@@ -7,7 +7,7 @@ import (
 )
 
 func main() {
-	conn, err := sql.Open("mysql-binlog", "{\"host\": \"127.0.0.1\", \"port\": 3306, \"user\": \"root\", \"password\": \"root\"}")
+	conn, err := sql.Open("mysql-binlog", "{\"host\": \"127.0.0.1\", \"port\": 3306, \"user\": \"root\", \"password\": \"root\", \"database\": \"test\", \"ssl\": false}")
 	if err != nil {
 		fmt.Printf("Open Error: %+v\n", err)
 	}