Просмотр исходного кода

Fixed scanning behavior to EOF

Josh Brickner 7 лет назад
Родитель
Сommit
ec48514fdc
3 измененных файлов с 28 добавлено и 38 удалено
  1. 2 11
      binlog/authentication.go
  2. 26 25
      binlog/connection.go
  3. 0 2
      binlog/handshake.go

+ 2 - 11
binlog/authentication.go

@@ -12,22 +12,13 @@ const SHA2_PERFORM_FULL_AUTHENTICATION = 0x04
 
 
 type AuthMoreDataPacket struct {
 type AuthMoreDataPacket struct {
 	PacketHeader
 	PacketHeader
-	Data string
+	Data uint64
 }
 }
 
 
 func (c *Conn) decodeAuthMoreDataResponsePacket(ph PacketHeader) (*AuthMoreDataPacket, error) {
 func (c *Conn) decodeAuthMoreDataResponsePacket(ph PacketHeader) (*AuthMoreDataPacket, error) {
 	md := AuthMoreDataPacket{}
 	md := AuthMoreDataPacket{}
 	md.PacketHeader = ph
 	md.PacketHeader = ph
-	flag := c.getInt(TypeFixedInt, 1)
-
-	switch flag {
-	case SHA2_FAST_AUTH_SUCCESS:
-		md.Data = "SHA2_FAST_AUTH_SUCCESS"
-	case SHA2_REQUEST_PUBLIC_KEY:
-		md.Data = "SHA2_REQUEST_PUBLIC_KEY"
-	case SHA2_PERFORM_FULL_AUTHENTICATION:
-		md.Data = "SHA2_PERFORM_FULL_AUTHENTICATION"
-	}
+	md.Data = c.getInt(TypeFixedInt, 1)
 
 
 	err := c.scanner.Err()
 	err := c.scanner.Err()
 	if err != nil {
 	if err != nil {

+ 26 - 25
binlog/connection.go

@@ -10,7 +10,6 @@ import (
 	"encoding/json"
 	"encoding/json"
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
-	"io"
 	"io/ioutil"
 	"io/ioutil"
 	"math"
 	"math"
 	"net"
 	"net"
@@ -50,14 +49,6 @@ type Config struct {
 	Timeout    time.Duration
 	Timeout    time.Duration
 }
 }
 
 
-func splitByBytesFunc(data []byte, atEOF bool) (advance int, token []byte, err error) {
-	if atEOF {
-		return 0, nil, io.EOF
-	}
-
-	return 1, data[:1], nil
-}
-
 func newBinlogConfig(dsn string) (*Config, error) {
 func newBinlogConfig(dsn string) (*Config, error) {
 	var err error
 	var err error
 
 
@@ -138,7 +129,7 @@ func (d Driver) Open(dsn string) (driver.Conn, error) {
 
 
 	c.HandshakeResponse = c.NewHandshakeResponse()
 	c.HandshakeResponse = c.NewHandshakeResponse()
 
 
-	// Send SSL_Request Packet
+	// If we are on SSL send SSL_Request packet now
 	if c.Config.SSL {
 	if c.Config.SSL {
 		err = c.writeSSLRequestPacket()
 		err = c.writeSSLRequestPacket()
 		if err != nil {
 		if err != nil {
@@ -168,9 +159,7 @@ func (d Driver) Open(dsn string) (driver.Conn, error) {
 }
 }
 
 
 func (c *Conn) listen() error {
 func (c *Conn) listen() error {
-	fmt.Println("LISTEN")
 	ph, err := c.getPacketHeader()
 	ph, err := c.getPacketHeader()
-
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
@@ -178,10 +167,24 @@ func (c *Conn) listen() error {
 	switch ph.Status {
 	switch ph.Status {
 	case 0x01:
 	case 0x01:
 		fmt.Println("IN: AuthMoreDate PACKET")
 		fmt.Println("IN: AuthMoreDate PACKET")
-		_, err := c.decodeAuthMoreDataResponsePacket(ph)
+		md, err := c.decodeAuthMoreDataResponsePacket(ph)
 		if err != nil {
 		if err != nil {
 			return err
 			return err
 		}
 		}
+
+		switch md.Data {
+		case SHA2_FAST_AUTH_SUCCESS:
+			fmt.Println("FAST AUTH")
+		case SHA2_REQUEST_PUBLIC_KEY:
+			fmt.Println("REQUEST PUBLIC KEY")
+		case SHA2_PERFORM_FULL_AUTHENTICATION:
+			fmt.Println("FULL AUTH")
+			c.putBytes(append([]byte(c.Config.Pass), NullByte))
+			if c.Flush() != nil {
+				return c.Flush()
+			}
+		}
+
 	case 0x00:
 	case 0x00:
 		fmt.Println("IN: OK PACKET")
 		fmt.Println("IN: OK PACKET")
 	case 0xFE:
 	case 0xFE:
@@ -195,9 +198,6 @@ func (c *Conn) listen() error {
 
 
 		err = errors.New(fmt.Sprintf("Error %d: %s", ep.ErrorCode, ep.ErrorMessage))
 		err = errors.New(fmt.Sprintf("Error %d: %s", ep.ErrorCode, ep.ErrorMessage))
 		return err
 		return err
-	default:
-		fmt.Printf("ph = %+v\n", ph)
-		fmt.Println("IN: UNKNOWN PACKET")
 	}
 	}
 
 
 	err = c.scanner.Err()
 	err = c.scanner.Err()
@@ -240,7 +240,11 @@ func init() {
 func (c *Conn) readBytes(l uint64) *bytes.Buffer {
 func (c *Conn) readBytes(l uint64) *bytes.Buffer {
 	b := make([]byte, 0)
 	b := make([]byte, 0)
 	for i := uint64(0); i < l; i++ {
 	for i := uint64(0); i < l; i++ {
-		c.scanner.Scan()
+		didScan := c.scanner.Scan()
+		if !didScan {
+			return nil
+		}
+
 		b = append(b, c.scanner.Bytes()...)
 		b = append(b, c.scanner.Bytes()...)
 	}
 	}
 
 
@@ -257,13 +261,9 @@ func (c *Conn) getBytesUntilEOF() *bytes.Buffer {
 			break
 			break
 		}
 		}
 
 
-		s = c.readBytes(uint64(l))
-
-		err := c.scanner.Err()
-		if err == io.EOF {
+		s := c.readBytes(uint64(l))
+		if s == nil {
 			return bytes.NewBuffer(b)
 			return bytes.NewBuffer(b)
-		} else if err != nil {
-			panic(err)
 		}
 		}
 
 
 		b = append(b, s.Bytes()...)
 		b = append(b, s.Bytes()...)
@@ -556,10 +556,11 @@ func (c *Conn) Flush() error {
 		return c.err
 		return c.err
 	}
 	}
 
 
+	fmt.Println(string(c.writeBuf.Bytes()))
 	c.writeBuf = c.addHeader()
 	c.writeBuf = c.addHeader()
 	_, _ = c.buffer.Write(c.writeBuf.Bytes())
 	_, _ = c.buffer.Write(c.writeBuf.Bytes())
 
 
-	// log all packets
+	// log all outgoing packets
 	fmt.Printf(
 	fmt.Printf(
 		"\nOUT:\n%08b\n%x\n%s\n\n",
 		"\nOUT:\n%08b\n%x\n%s\n\n",
 		c.writeBuf.Bytes(),
 		c.writeBuf.Bytes(),
@@ -626,5 +627,5 @@ func (c *Conn) setConnection(nc net.Conn) {
 	)
 	)
 
 
 	c.scanner = bufio.NewScanner(c.buffer.Reader)
 	c.scanner = bufio.NewScanner(c.buffer.Reader)
-	c.scanner.Split(splitByBytesFunc)
+	c.scanner.Split(bufio.ScanBytes)
 }
 }

+ 0 - 2
binlog/handshake.go

@@ -4,7 +4,6 @@ import (
 	"bytes"
 	"bytes"
 	"crypto/tls"
 	"crypto/tls"
 	"crypto/x509"
 	"crypto/x509"
-	"fmt"
 	"io/ioutil"
 	"io/ioutil"
 )
 )
 
 
@@ -247,7 +246,6 @@ func (c *Conn) NewHandshakeResponse() *HandshakeResponse {
 // generate TLS config for client side
 // generate TLS config for client side
 // if insecureSkipVerify is set to true, serverName will not be validated
 // if insecureSkipVerify is set to true, serverName will not be validated
 func NewClientTLSConfig(keyPem string, cerPem string, caPem []byte, insecureSkipVerify bool, serverName string) *tls.Config {
 func NewClientTLSConfig(keyPem string, cerPem string, caPem []byte, insecureSkipVerify bool, serverName string) *tls.Config {
-	fmt.Printf("insecureSkipVerify = %+v\n", insecureSkipVerify)
 	config := &tls.Config{
 	config := &tls.Config{
 		InsecureSkipVerify: !insecureSkipVerify,
 		InsecureSkipVerify: !insecureSkipVerify,
 		ServerName:         serverName,
 		ServerName:         serverName,