Selaa lähdekoodia

Corrected implementation of fix integer decoding

Josh Brickner 7 vuotta sitten
vanhempi
sitoutus
2a899cbcfb
3 muutettua tiedostoa jossa 91 lisäystä ja 16 poistoa
  1. 47 11
      binlog/authentication.go
  2. 39 4
      binlog/connection.go
  3. 5 1
      binlog/handshake.go

+ 47 - 11
binlog/authentication.go

@@ -1,25 +1,54 @@
 package binlog
 
 import (
+	"bytes"
 	"crypto/sha1"
 	"crypto/sha256"
+	"fmt"
 )
 
-func (c *Conn) authenticate(hr *HandshakeResponse) {
-	switch c.Handshake.AuthPluginName {
-	case "mysql_native_password":
-		c.doSha1Auth(hr)
-	case "caching_sha2_password":
-		c.doSha2Auth(hr)
+type AuthResponse struct {
+	PacketLength   uint64
+	SequenceID     uint64
+	Status         uint64
+	PluginName     string
+	AuthPluginData *bytes.Buffer
+}
+
+func (c *Conn) decodeAuthResponsePacket() (*AuthResponse, error) {
+	packet := AuthResponse{}
+
+	packet.PacketLength = c.getInt(TypeFixedInt, 3)
+	packet.SequenceID = c.getInt(TypeFixedInt, 1)
+	packet.Status = c.getInt(TypeFixedInt, 1)
+	packet.PluginName = c.getString(TypeNullTerminatedString, 0)
+	packet.AuthPluginData = c.readBytes(20)
+
+	err := c.scanner.Err()
+	if err != nil {
+		return nil, err
 	}
+
+	return &packet, err
 }
 
-func (c *Conn) doSha1Auth(hr *HandshakeResponse) {
+func (c *Conn) writeAuthSwitchPacket() {
+
 }
 
-func (c *Conn) doSha2Auth(hr *HandshakeResponse) {
+func (c *Conn) authenticate(hr *HandshakeResponse) {
+	var ar []byte
 	salt := append(c.Handshake.AuthPluginDataPart1.Bytes(), c.Handshake.AuthPluginDataPart2.Bytes()...)
-	ar := c.cachingSha2Auth(salt, []byte(hr.AuthResponse))
+	password := []byte(hr.AuthResponse)
+	fmt.Println(hr.AuthResponse)
+
+	switch c.Handshake.AuthPluginName {
+	case "mysql_native_password":
+		ar = c.nativeSha1Auth(salt, password)
+	case "caching_sha2_password":
+		ar = c.cachingSha2Auth(salt, password)
+	}
+
 	hr.AuthResponseLength = uint64(len(ar))
 	if hr.ClientFlag.PluginAuthLenEncClientData {
 		c.putInt(TypeLenEncInt, hr.AuthResponseLength, 0)
@@ -32,9 +61,16 @@ func (c *Conn) doSha2Auth(hr *HandshakeResponse) {
 	}
 }
 
-func (c *Conn) nativeSha1Auth() {
-	// SHA1(password) XOR SHA1("20-bytes random data from server" <concat> SHA1(SHA1(password)))
+func (c *Conn) nativeSha1Auth(salt []byte, password []byte) []byte {
+	pHash := c.sha1Hash(password)
+	pHashHash := c.sha1Hash(pHash)
+	spHash := c.sha1Hash(append(salt, pHashHash...))
 
+	for i := range pHash {
+		pHash[i] ^= spHash[i]
+	}
+
+	return pHash
 }
 
 func (c *Conn) cachingSha2Auth(salt []byte, password []byte) []byte {

+ 39 - 4
binlog/connection.go

@@ -110,14 +110,15 @@ func (d Driver) Open(dsn string) (driver.Conn, error) {
 	dialer := net.Dialer{Timeout: c.Config.Timeout}
 	addr := fmt.Sprintf("%s:%d", c.Config.Host, c.Config.Port)
 	t, err := dialer.Dial("tcp", addr)
-	c.tcpConn = t.(*net.TCPConn)
 
 	if err != nil {
 		netErr, ok := err.(net.Error)
-		if ok && netErr.Temporary() {
+		if ok && !netErr.Temporary() {
 			fmt.Printf("Error: %s", netErr.Error())
 			return nil, err
 		}
+	} else {
+		c.tcpConn = t.(*net.TCPConn)
 	}
 
 	err = c.decodeHandshakePacket()
@@ -130,7 +131,12 @@ func (d Driver) Open(dsn string) (driver.Conn, error) {
 		return nil, err
 	}
 
-	fmt.Printf("%+v\n", c.Handshake)
+	packet, err := c.decodeAuthResponsePacket()
+	if err != nil {
+		return nil, err
+	}
+
+	fmt.Printf("%+v\n", packet)
 	return c, err
 }
 
@@ -222,10 +228,39 @@ func (c *Conn) decFixedString(l uint64) string {
 func (c *Conn) decFixedInt(l uint64) uint64 {
 	var i uint64
 	b := c.readBytes(l)
-	i, _ = binary.ReadUvarint(b)
+	if l <= 2 {
+		var x uint16
+		pb := c.padBytes(2, b.Bytes())
+		br := bytes.NewReader(pb)
+		_ = binary.Read(br, binary.LittleEndian, &x)
+		i = uint64(x)
+	} else if l <= 4 {
+		var x uint32
+		pb := c.padBytes(4, b.Bytes())
+		br := bytes.NewReader(pb)
+		_ = binary.Read(br, binary.LittleEndian, &x)
+		i = uint64(x)
+	} else if l <= 8 {
+		var x uint64
+		pb := c.padBytes(8, b.Bytes())
+		br := bytes.NewReader(pb)
+		_ = binary.Read(br, binary.LittleEndian, &x)
+		i = x
+	}
+
 	return i
 }
 
+func (c *Conn) padBytes(l int, b []byte) []byte {
+	bl := len(b)
+	pl := l - bl
+	for i := 0; i < pl; i++ {
+		b = append(b, NullByte)
+	}
+
+	return b
+}
+
 func (c *Conn) encFixedLenInt(v uint64, l uint64) []byte {
 	b := make([]byte, 8)
 	binary.LittleEndian.PutUint64(b, v)

+ 5 - 1
binlog/handshake.go

@@ -2,6 +2,7 @@ package binlog
 
 import (
 	"bytes"
+	"fmt"
 )
 
 type Capabilities struct {
@@ -111,7 +112,8 @@ func (c *Conn) decodeHandshakePacket() error {
 	c.decodeCapabilityFlags(&packet)
 	packet.AuthPluginDataLength = c.getInt(TypeFixedInt, 1)
 	c.discardBytes(10)
-	packet.AuthPluginDataPart2 = c.readBytes(packet.AuthPluginDataLength - 8)
+	p1l := uint64(packet.AuthPluginDataPart1.Len())
+	packet.AuthPluginDataPart2 = c.readBytes(packet.AuthPluginDataLength - p1l)
 	packet.AuthPluginName = c.getString(TypeNullTerminatedString, 0)
 
 	err := c.scanner.Err()
@@ -154,6 +156,8 @@ func (c *Conn) writeHandshakeResponse() error {
 		c.putString(t, hr.ClientPluginName)
 	}
 
+	fmt.Printf("%+v\n", hr)
+
 	if c.Flush() != nil {
 		return c.Flush()
 	}