Ver Fonte

Work on transition to using bytes.Reader

Josh Brickner há 7 anos atrás
pai
commit
0e3289a101
2 ficheiros alterados com 89 adições e 165 exclusões
  1. 67 85
      binlog/connection.go
  2. 22 80
      binlog/handshake.go

+ 67 - 85
binlog/connection.go

@@ -1,6 +1,7 @@
 package binlog
 
 import (
+	"bufio"
 	"bytes"
 	"database/sql"
 	"database/sql/driver"
@@ -16,7 +17,8 @@ import (
 const TypeNullTerminatedString = int(0)
 const TypeFixedString = int(1)
 const TypeFixedInt = int(2)
-const TypeLenEncodedInt = int(3)
+
+//const TypeLenEncodedInt = int(3)
 
 // Integer Maximums
 const MaxUint8 = 1<<8 - 1
@@ -52,6 +54,7 @@ type Conn struct {
 	Config    *Config
 	tcpConn   *net.TCPConn
 	Handshake *Handshake
+	buffer    *bufio.ReadWriter
 }
 
 func newBinlogConn(config *Config) Conn {
@@ -80,12 +83,12 @@ func (d Driver) Open(dsn string) (driver.Conn, error) {
 		return nil, err
 	}
 
-	blConn := newBinlogConn(config)
+	c := newBinlogConn(config)
 
-	dialer := net.Dialer{Timeout: blConn.Config.Timeout}
-	addr := fmt.Sprintf("%s:%d", blConn.Config.Host, blConn.Config.Port)
-	c, err := dialer.Dial("tcp", addr)
-	blConn.tcpConn = c.(*net.TCPConn)
+	dialer := net.Dialer{Timeout: c.Config.Timeout}
+	addr := fmt.Sprintf("%s:%d", c.Config.Host, c.Config.Port)
+	t, err := dialer.Dial("tcp", addr)
+	c.tcpConn = t.(*net.TCPConn)
 
 	if err != nil {
 		netErr, ok := err.(net.Error)
@@ -95,129 +98,108 @@ func (d Driver) Open(dsn string) (driver.Conn, error) {
 		}
 	}
 
-	err = blConn.decodeHandshakePacket()
-	b := blConn.encodeHandshakeResponse()
-	fmt.Printf("%08b\n%d\n%s", b, b, b)
+	err = c.decodeHandshakePacket()
+	fmt.Printf("%+v", c.Handshake)
+	//b := c.encodeHandshakeResponse()
+	//fmt.Printf("%08b\n%d\n%s", b, b, b)
 
-	_, err = blConn.tcpConn.Write(b)
+	//_, err = c.tcpConn.Write(b)
 
-	return blConn, err
+	return c, err
 }
 
 func init() {
 	sql.Register("mysql-binlog", &Driver{})
 }
 
-func (c *Conn) getBytes(l uint64) ([]byte, error) {
-	b := make([]byte, l)
-	_, err := c.tcpConn.Read(b)
+func (c *Conn) getPacketLength() uint64 {
+	l := c.getInt(TypeFixedInt, 3)
+	return l
+}
 
+func (c *Conn) readWholePacket() (*bytes.Buffer, error) {
+	pl := c.getPacketLength()
+	b, err := c.readBytes(pl - 3)
 	return b, err
 }
 
-func (c *Conn) consumeBytes(l uint64) error {
+func (c *Conn) readBytes(l uint64) (*bytes.Buffer, error) {
+	if c.buffer == nil {
+		c.buffer = bufio.NewReadWriter(
+			bufio.NewReader(c.tcpConn),
+			bufio.NewWriter(c.tcpConn),
+		)
+	}
+
+	b := make([]byte, l)
+	_, err := c.buffer.Read(b)
+	if err != nil {
+		return nil, err
+	}
+
+	return bytes.NewBuffer(b), nil
+}
+
+func (c *Conn) getBytes(l uint64) *bytes.Buffer {
 	b := make([]byte, l)
-	_, err := c.tcpConn.Read(b)
+	_, _ = c.buffer.Read(b)
+	return bytes.NewBuffer(b)
+}
 
-	return err
+func (c *Conn) getBytesUntilNull() *bytes.Buffer {
+	s, _ := c.buffer.ReadBytes(NullByte)
+	return bytes.NewBuffer(s)
 }
 
-func (c *Conn) getInt(t int, l uint64) (uint64, error) {
+func (c *Conn) discardBytes(l int) {
+	_, _ = c.buffer.Discard(l)
+}
+
+func (c *Conn) getInt(t int, l uint64) uint64 {
 	var v uint64
-	var err error = nil
 
 	switch t {
 	case TypeFixedInt:
-		v, err = c.popFixedInt(l)
+		v = c.decFixedInt(l)
 	default:
 		v = 0
 	}
 
-	if err != nil {
-		return 0, err
-	}
-
-	return v, nil
+	return v
 }
 
-func (c *Conn) getString(t int, l uint64) (string, error) {
+func (c *Conn) getString(t int, l uint64) string {
 	var v string
-	var err error = nil
 
 	switch t {
 	case TypeFixedString:
-		v, err = c.popFixedString(l)
+		v = c.decFixedString(l)
 	case TypeNullTerminatedString:
-		v, err = c.popNullTerminatedString()
+		v = c.decNullTerminatedString()
 	default:
 		v = ""
 	}
 
-	if err != nil {
-		return "", err
-	}
-
-	return v, nil
-}
-
-func (c *Conn) readBytes(l uint64) (*bytes.Buffer, error) {
-	b := make([]byte, l)
-	_, err := c.tcpConn.Read(b)
-	if err != nil {
-		return nil, err
-	}
-
-	return bytes.NewBuffer(b), nil
-}
-
-func (c *Conn) readToNull() (*bytes.Buffer, error) {
-	var s []byte
-	for {
-		bA := make([]byte, 1)
-		_, err := c.tcpConn.Read(bA)
-		if err != nil {
-			return nil, err
-		}
-
-		b := bA[0]
-		if b == NullByte {
-			break
-		} else {
-			s = append(s, b)
-		}
-	}
-
-	return bytes.NewBuffer(s), nil
+	return v
 }
 
-func (c *Conn) popNullTerminatedString() (string, error) {
-	b, err := c.readToNull()
-	if err != nil {
-		return "", err
-	}
-
-	return string(b.Bytes()), nil
+func (c *Conn) decNullTerminatedString() string {
+	b := c.getBytesUntilNull()
+	return string(b.Bytes())
 }
 
-func (c *Conn) popFixedString(l uint64) (string, error) {
-	b, err := c.readBytes(l)
-	if err != nil {
-		return "", err
-	}
-
-	return string(b.Bytes()), nil
+func (c *Conn) decFixedString(l uint64) string {
+	b, _ := c.readBytes(l)
+	return b.String()
 }
 
-func (c *Conn) popFixedInt(l uint64) (uint64, error) {
-	b, err := c.readBytes(l)
-	if err != nil {
-		return 0, err
-	}
+func (c *Conn) decFixedInt(l uint64) uint64 {
+	b, _ := c.readBytes(l)
 
 	var i uint64
-	i, err = binary.ReadUvarint(b)
+	i, _ = binary.ReadUvarint(b)
 
-	return i, err
+	return i
 }
 
 func (c *Conn) encFixedLenInt(l uint64, v uint64) []byte {

+ 22 - 80
binlog/handshake.go

@@ -3,6 +3,8 @@ package binlog
 import (
 	"bytes"
 	"encoding/binary"
+	"fmt"
+	"os"
 )
 
 type Capabilities struct {
@@ -96,89 +98,29 @@ func (c *Conn) decodeStatusFlags(hs *Handshake) {
 
 func (c *Conn) decodeHandshakePacket() error {
 	packet := Handshake{}
-	var err error
-
-	packet.PacketLength, err = c.getInt(TypeFixedInt, 3)
-	if err != nil {
-		return err
-	}
-
-	packet.SequenceID, err = c.getInt(TypeFixedInt, 1)
-	if err != nil {
-		return err
-	}
-
-	packet.ProtocolVersion, err = c.getInt(TypeFixedInt, 1)
-	if err != nil {
-		return err
-	}
-
-	packet.ServerVersion, err = c.getString(TypeNullTerminatedString, 0)
-	if err != nil {
-		return err
-	}
-
-	packet.ThreadID, err = c.getInt(TypeFixedInt, 4)
-	if err != nil {
-		return err
-	}
-
-	packet.AuthPluginDataPart1, err = c.getBytes(8)
-	if err != nil {
-		return err
-	}
-
-	err = c.consumeBytes(1)
-	if err != nil {
-		return err
-	}
-
-	packet.CapabilityFlags1, err = c.getBytes(2)
-	if err != nil {
-		return err
-	}
-
-	packet.Charset, err = c.getInt(TypeFixedInt, 1)
-	if err != nil {
-		return err
-	}
-
-	packet.StatusFlags, err = c.getBytes(2)
-	if err != nil {
-		return err
-	}
-
+	err := c.readWholePacket()
+	fmt.Println("\nEND")
+	os.Exit(0)
+	packet.PacketLength = c.getInt(TypeFixedInt, 3)
+	packet.SequenceID = c.getInt(TypeFixedInt, 1)
+	packet.ProtocolVersion = c.getInt(TypeFixedInt, 1)
+	packet.ServerVersion = c.getString(TypeNullTerminatedString, 0)
+	packet.ThreadID = c.getInt(TypeFixedInt, 4)
+	packet.AuthPluginDataPart1 = c.getBytes(8).Bytes()
+	c.discardBytes(1)
+	packet.CapabilityFlags1 = c.getBytes(2).Bytes()
+	packet.Charset = c.getInt(TypeFixedInt, 1)
+	packet.StatusFlags = c.getBytes(2).Bytes()
 	c.decodeStatusFlags(&packet)
-
-	packet.CapabilityFlags2, err = c.getBytes(2)
-	if err != nil {
-		return err
-	}
-
+	packet.CapabilityFlags2 = c.getBytes(2).Bytes()
 	c.decodeCapabilityFlags(&packet)
-
-	packet.AuthPluginDataLength, err = c.getInt(TypeFixedInt, 1)
-	if err != nil {
-		return err
-	}
-
-	err = c.consumeBytes(10)
-	if err != nil {
-		return err
-	}
-
-	packet.AuthPluginDataPart2, err = c.getBytes(packet.AuthPluginDataLength - 8)
-	if err != nil {
-		return err
-	}
-
-	packet.AuthPluginName, err = c.getString(TypeNullTerminatedString, 0)
-	if err != nil {
-		return err
-	}
-
+	packet.AuthPluginDataLength = c.getInt(TypeFixedInt, 1)
+	c.discardBytes(10)
+	packet.AuthPluginDataPart2 = c.getBytes(packet.AuthPluginDataLength - 8).Bytes()
+	packet.AuthPluginName = c.getString(TypeNullTerminatedString, 0)
 	c.Handshake = &packet
-	return nil
+
+	return err
 }
 
 func (c *Conn) encodeHandshakeResponse() []byte {