Ver código fonte

Completed refactor to Scanner API

Josh Brickner 7 anos atrás
pai
commit
dd09af1190
2 arquivos alterados com 70 adições e 54 exclusões
  1. 47 32
      binlog/connection.go
  2. 23 22
      binlog/handshake.go

+ 47 - 32
binlog/connection.go

@@ -41,6 +41,16 @@ type Config struct {
 	Timeout    time.Duration
 }
 
+func splitByBytesFunc(data []byte, atEOF bool) (advance int, token []byte, err error) {
+	//fmt.Printf("DATA: %08b\n atEOF: %08b\n", data, atEOF)
+	if atEOF {
+		return 0, nil, nil
+	}
+
+	//fmt.Printf("RETURN DATA: %+v\n", data[:1])
+	return 1, data[:1], nil
+}
+
 func newBinlogConfig(dsn string) (*Config, error) {
 	var err error
 
@@ -55,6 +65,7 @@ type Conn struct {
 	tcpConn   *net.TCPConn
 	Handshake *Handshake
 	buffer    *bufio.ReadWriter
+	scanner   *bufio.Scanner
 }
 
 func newBinlogConn(config *Config) Conn {
@@ -99,11 +110,14 @@ func (d Driver) Open(dsn string) (driver.Conn, error) {
 	}
 
 	err = c.decodeHandshakePacket()
-	fmt.Printf("%+v", c.Handshake)
-	//b := c.encodeHandshakeResponse()
-	//fmt.Printf("%08b\n%d\n%s", b, b, b)
+	if err != nil {
+		return nil, err
+	}
 
-	//_, err = c.tcpConn.Write(b)
+	err = c.encodeHandshakeResponse()
+	if err != nil {
+		return nil, err
+	}
 
 	return c, err
 }
@@ -112,47 +126,48 @@ func init() {
 	sql.Register("mysql-binlog", &Driver{})
 }
 
-func (c *Conn) getPacketLength() uint64 {
-	l := c.getInt(TypeFixedInt, 3)
-	return l
-}
-
-func (c *Conn) readWholePacket() (*bytes.Buffer, error) {
-	pl := c.getPacketLength()
-	b, err := c.readBytes(pl - 3)
-	return b, err
-}
-
-func (c *Conn) readBytes(l uint64) (*bytes.Buffer, error) {
+func (c *Conn) readBytes(l uint64) *bytes.Buffer {
 	if c.buffer == nil {
 		c.buffer = bufio.NewReadWriter(
 			bufio.NewReader(c.tcpConn),
 			bufio.NewWriter(c.tcpConn),
 		)
-	}
 
-	b := make([]byte, l)
-	_, err := c.buffer.Read(b)
-	if err != nil {
-		return nil, err
+		c.scanner = bufio.NewScanner(c.buffer.Reader)
+		c.scanner.Split(splitByBytesFunc)
 	}
 
-	return bytes.NewBuffer(b), nil
-}
+	b := make([]byte, 0)
+	for i := uint64(0); i < l; i++ {
+		c.scanner.Scan()
+		b = append(b, c.scanner.Bytes()...)
+	}
 
-func (c *Conn) getBytes(l uint64) *bytes.Buffer {
-	b := make([]byte, l)
-	_, _ = c.buffer.Read(b)
 	return bytes.NewBuffer(b)
 }
 
 func (c *Conn) getBytesUntilNull() *bytes.Buffer {
-	s, _ := c.buffer.ReadBytes(NullByte)
-	return bytes.NewBuffer(s)
+
+	l := uint64(1)
+	s := c.readBytes(l)
+	b := s.Bytes()
+
+	for true {
+		if uint64(s.Len()) != l || s.Bytes()[0] == NullByte {
+			break
+		}
+
+		s = c.readBytes(uint64(l))
+		b = append(b, s.Bytes()...)
+	}
+
+	return bytes.NewBuffer(b[:len(b)-1])
 }
 
 func (c *Conn) discardBytes(l int) {
-	_, _ = c.buffer.Discard(l)
+	for i := 0; i < l; i++ {
+		c.scanner.Scan()
+	}
 }
 
 func (c *Conn) getInt(t int, l uint64) uint64 {
@@ -185,16 +200,16 @@ func (c *Conn) getString(t int, l uint64) string {
 
 func (c *Conn) decNullTerminatedString() string {
 	b := c.getBytesUntilNull()
-	return string(b.Bytes())
+	return b.String()
 }
 
 func (c *Conn) decFixedString(l uint64) string {
-	b, _ := c.readBytes(l)
+	b := c.readBytes(l)
 	return b.String()
 }
 
 func (c *Conn) decFixedInt(l uint64) uint64 {
-	b, _ := c.readBytes(l)
+	b := c.readBytes(l)
 
 	var i uint64
 	i, _ = binary.ReadUvarint(b)

+ 23 - 22
binlog/handshake.go

@@ -3,8 +3,6 @@ package binlog
 import (
 	"bytes"
 	"encoding/binary"
-	"fmt"
-	"os"
 )
 
 type Capabilities struct {
@@ -61,13 +59,13 @@ type Handshake struct {
 	ProtocolVersion      uint64
 	ServerVersion        string
 	ThreadID             uint64
-	AuthPluginDataPart1  []byte
-	CapabilityFlags1     []byte
+	AuthPluginDataPart1  *bytes.Buffer
+	CapabilityFlags1     *bytes.Buffer
 	Charset              uint64
-	StatusFlags          []byte
-	CapabilityFlags2     []byte
+	StatusFlags          *bytes.Buffer
+	CapabilityFlags2     *bytes.Buffer
 	AuthPluginDataLength uint64
-	AuthPluginDataPart2  []byte
+	AuthPluginDataPart2  *bytes.Buffer
 	AuthPluginName       string
 	Capabilities         *Capabilities
 	Status               *Status
@@ -86,41 +84,44 @@ type HandshakeResponse struct {
 }
 
 func (c *Conn) decodeCapabilityFlags(hs *Handshake) {
-	var cfb = append(hs.CapabilityFlags1, hs.CapabilityFlags2...)
+	var cfb = append(hs.CapabilityFlags1.Bytes(), hs.CapabilityFlags2.Bytes()...)
 	capabilities := c.bitmaskToStruct(cfb, hs.Capabilities).(Capabilities)
 	hs.Capabilities = &capabilities
 }
 
 func (c *Conn) decodeStatusFlags(hs *Handshake) {
-	status := c.bitmaskToStruct(hs.StatusFlags, hs.Status).(Status)
+	status := c.bitmaskToStruct(hs.StatusFlags.Bytes(), hs.Status).(Status)
 	hs.Status = &status
 }
 
 func (c *Conn) decodeHandshakePacket() error {
 	packet := Handshake{}
-	err := c.readWholePacket()
-	fmt.Println("\nEND")
-	os.Exit(0)
+
 	packet.PacketLength = c.getInt(TypeFixedInt, 3)
 	packet.SequenceID = c.getInt(TypeFixedInt, 1)
 	packet.ProtocolVersion = c.getInt(TypeFixedInt, 1)
 	packet.ServerVersion = c.getString(TypeNullTerminatedString, 0)
 	packet.ThreadID = c.getInt(TypeFixedInt, 4)
-	packet.AuthPluginDataPart1 = c.getBytes(8).Bytes()
+	packet.AuthPluginDataPart1 = c.readBytes(8)
 	c.discardBytes(1)
-	packet.CapabilityFlags1 = c.getBytes(2).Bytes()
+	packet.CapabilityFlags1 = c.readBytes(2)
 	packet.Charset = c.getInt(TypeFixedInt, 1)
-	packet.StatusFlags = c.getBytes(2).Bytes()
+	packet.StatusFlags = c.readBytes(2)
 	c.decodeStatusFlags(&packet)
-	packet.CapabilityFlags2 = c.getBytes(2).Bytes()
+	packet.CapabilityFlags2 = c.readBytes(2)
 	c.decodeCapabilityFlags(&packet)
 	packet.AuthPluginDataLength = c.getInt(TypeFixedInt, 1)
 	c.discardBytes(10)
-	packet.AuthPluginDataPart2 = c.getBytes(packet.AuthPluginDataLength - 8).Bytes()
+	packet.AuthPluginDataPart2 = c.readBytes(packet.AuthPluginDataLength - 8)
 	packet.AuthPluginName = c.getString(TypeNullTerminatedString, 0)
+
+	if c.scanner.Err() != nil {
+		return c.scanner.Err()
+	}
+
 	c.Handshake = &packet
 
-	return err
+	return nil
 }
 
 func (c *Conn) encodeHandshakeResponse() []byte {
@@ -128,13 +129,13 @@ func (c *Conn) encodeHandshakeResponse() []byte {
 	buf := bytes.NewBuffer(make([]byte, 0))
 
 	// Capabilities flag.
-	//var cf capability = 0
+	var cf capability = 0
 
 	// Write Capability Flags.
-	//buf.Write([]byte(cf))
+	buf.Write([]byte(cf))
 
 	// Write MaxPacketSize
-	//buf.Write()
+	buf.Write(MaxPacketSize)
 
 	// Write CharacterSet
 	cs := make([]byte, 2)
@@ -148,7 +149,7 @@ func (c *Conn) encodeHandshakeResponse() []byte {
 	u := append([]byte(hr.Username), NullByte)
 	buf.Write(u)
 
-	salt := append(c.Handshake.AuthPluginDataPart1, c.Handshake.AuthPluginDataPart2...)
+	salt := append(c.Handshake.AuthPluginDataPart1.Bytes(), c.Handshake.AuthPluginDataPart2.Bytes()...)
 	ar := c.cachingSha2Auth(salt, []byte(hr.AuthResponse))
 	if hr.ClientFlag.PluginAuthLenEncClientData {
 		buf.Write(c.encLenEncInt(uint64(len(ar))))