Browse Source

Completed refactor of write to buffer code.

Josh Brickner 7 năm trước cách đây
mục cha
commit
81cfff8b5d
2 tập tin đã thay đổi với 201 bổ sung77 xóa
  1. 166 22
      binlog/connection.go
  2. 35 55
      binlog/handshake.go

+ 166 - 22
binlog/connection.go

@@ -7,9 +7,13 @@ import (
 	"database/sql/driver"
 	"encoding/binary"
 	"encoding/json"
+	"errors"
 	"fmt"
+	"math"
+	"math/bits"
 	"net"
 	"reflect"
+	"strings"
 	"time"
 )
 
@@ -17,8 +21,7 @@ import (
 const TypeNullTerminatedString = int(0)
 const TypeFixedString = int(1)
 const TypeFixedInt = int(2)
-
-//const TypeLenEncodedInt = int(3)
+const TypeLenEncInt = int(3)
 
 // Integer Maximums
 const MaxUint8 = 1<<8 - 1
@@ -42,12 +45,10 @@ type Config struct {
 }
 
 func splitByBytesFunc(data []byte, atEOF bool) (advance int, token []byte, err error) {
-	//fmt.Printf("DATA: %08b\n atEOF: %08b\n", data, atEOF)
 	if atEOF {
-		return 0, nil, nil
+		return 0, nil, errors.New("scanner found EOF")
 	}
 
-	//fmt.Printf("RETURN DATA: %+v\n", data[:1])
 	return 1, data[:1], nil
 }
 
@@ -61,16 +62,20 @@ func newBinlogConfig(dsn string) (*Config, error) {
 }
 
 type Conn struct {
-	Config    *Config
-	tcpConn   *net.TCPConn
-	Handshake *Handshake
-	buffer    *bufio.ReadWriter
-	scanner   *bufio.Scanner
+	Config     *Config
+	tcpConn    *net.TCPConn
+	Handshake  *Handshake
+	buffer     *bufio.ReadWriter
+	scanner    *bufio.Scanner
+	err        error
+	sequenceId uint64
+	writeBuf   *bytes.Buffer
 }
 
 func newBinlogConn(config *Config) Conn {
 	return Conn{
-		Config: config,
+		Config:     config,
+		sequenceId: 0,
 	}
 }
 
@@ -114,7 +119,7 @@ func (d Driver) Open(dsn string) (driver.Conn, error) {
 		return nil, err
 	}
 
-	err = c.encodeHandshakeResponse()
+	err = c.writeHandshakeResponse()
 	if err != nil {
 		return nil, err
 	}
@@ -147,7 +152,6 @@ func (c *Conn) readBytes(l uint64) *bytes.Buffer {
 }
 
 func (c *Conn) getBytesUntilNull() *bytes.Buffer {
-
 	l := uint64(1)
 	s := c.readBytes(l)
 	b := s.Bytes()
@@ -161,7 +165,7 @@ func (c *Conn) getBytesUntilNull() *bytes.Buffer {
 		b = append(b, s.Bytes()...)
 	}
 
-	return bytes.NewBuffer(b[:len(b)-1])
+	return bytes.NewBuffer(b)
 }
 
 func (c *Conn) discardBytes(l int) {
@@ -200,7 +204,7 @@ func (c *Conn) getString(t int, l uint64) string {
 
 func (c *Conn) decNullTerminatedString() string {
 	b := c.getBytesUntilNull()
-	return b.String()
+	return strings.TrimRight(b.String(), string(NullByte))
 }
 
 func (c *Conn) decFixedString(l uint64) string {
@@ -209,18 +213,16 @@ func (c *Conn) decFixedString(l uint64) string {
 }
 
 func (c *Conn) decFixedInt(l uint64) uint64 {
-	b := c.readBytes(l)
-
 	var i uint64
+	b := c.readBytes(l)
 	i, _ = binary.ReadUvarint(b)
-
 	return i
 }
 
-func (c *Conn) encFixedLenInt(l uint64, v uint64) []byte {
-	b := make([]byte, 4)
+func (c *Conn) encFixedLenInt(v uint64, l uint64) []byte {
+	b := make([]byte, 8)
 	binary.LittleEndian.PutUint64(b, v)
-	return b[:(l - 1)]
+	return b[:l]
 }
 
 func (c *Conn) encLenEncInt(v uint64) []byte {
@@ -247,7 +249,9 @@ func (c *Conn) encLenEncInt(v uint64) []byte {
 		binary.LittleEndian.PutUint64(b, uint64(v))
 	}
 
-	b = append(prefix, b...)
+	if len(b) > 1 {
+		b = append(prefix, b...)
+	}
 	return b
 }
 
@@ -282,3 +286,143 @@ func (c *Conn) bitmaskToStruct(b []byte, s interface{}) interface{} {
 
 	return v.Interface()
 }
+
+func (c *Conn) structToBitmask(s interface{}) []byte {
+	t := reflect.TypeOf(s).Elem()
+	sV := reflect.ValueOf(s).Elem()
+	fC := uint(t.NumField())
+	m := uint64(0)
+	for i := uint(0); i < fC; i++ {
+		f := sV.Field(int(i))
+		v := f.Bool()
+		if v {
+			m |= 1 << i
+		}
+	}
+
+	l := uint64(math.Ceil(float64(fC) / 8.0))
+	b := make([]byte, 8)
+	binary.BigEndian.PutUint64(b, bits.Reverse64(m))
+
+	switch {
+	case l > 4: // 64 bits
+		b = b[:8]
+	case l > 2: // 32 bits
+		b = b[:4]
+	case l > 1: // 16 bits
+		b = b[:2]
+	default: // 8 bits
+		b = b[:1]
+	}
+
+	return b
+}
+
+func (c *Conn) putString(t int, v string) uint64 {
+	b := make([]byte, 0)
+
+	switch t {
+	case TypeFixedString:
+		b = c.encFixedString(v)
+	case TypeNullTerminatedString:
+		b = c.encNullTerminatedString(v)
+	}
+
+	l, err := c.writeBuf.Write(b)
+	if err != nil {
+		c.err = err
+	}
+
+	return uint64(l)
+}
+
+func (c *Conn) encNullTerminatedString(v string) []byte {
+	return append([]byte(v), NullByte)
+}
+
+func (c *Conn) encFixedString(v string) []byte {
+	return []byte(v)
+}
+
+func (c *Conn) putInt(t int, v uint64, l uint64) uint64 {
+	c.setupWriteBuffer()
+
+	b := make([]byte, 0)
+
+	switch t {
+	case TypeFixedInt:
+		b = c.encFixedLenInt(v, l)
+	case TypeLenEncInt:
+		b = c.encLenEncInt(v)
+	}
+
+	n, err := c.writeBuf.Write(b)
+	if err != nil {
+		c.err = err
+	}
+
+	return uint64(n)
+}
+
+func (c *Conn) putNullBytes(n uint64) uint64 {
+	c.setupWriteBuffer()
+
+	b := make([]byte, n)
+	l, err := c.writeBuf.Write(b)
+	if err != nil {
+		c.err = err
+	}
+
+	return uint64(l)
+}
+
+func (c *Conn) putBytes(v []byte) uint64 {
+	c.setupWriteBuffer()
+
+	l, err := c.writeBuf.Write(v)
+	if err != nil {
+		c.err = err
+	}
+
+	return uint64(l)
+}
+
+func (c *Conn) Flush() error {
+	if c.err != nil {
+		return c.err
+	}
+
+	c.writeBuf = c.addHeader()
+	_, _ = c.buffer.Write(c.writeBuf.Bytes())
+	if c.buffer.Flush() != nil {
+		return c.buffer.Flush()
+	}
+
+	return nil
+}
+
+func (c *Conn) addHeader() *bytes.Buffer {
+	pl := uint64(c.writeBuf.Len()) + 4
+	sId := uint64(c.sequenceId)
+	c.sequenceId++
+
+	plB := c.encFixedLenInt(pl, 3)
+	sIdB := c.encFixedLenInt(sId, 1)
+
+	return bytes.NewBuffer(append(append(plB, sIdB...), c.writeBuf.Bytes()...))
+}
+
+func (c *Conn) setupWriteBuffer() {
+	if c.writeBuf == nil {
+		c.writeBuf = bytes.NewBuffer(nil)
+	}
+}
+
+func Reverse(s string) string {
+	var b strings.Builder
+	b.Grow(len(s))
+	for i := len(s) - 1; i >= 0; i-- {
+		b.WriteByte(s[i])
+	}
+	return b.String()
+}

+ 35 - 55
binlog/handshake.go

@@ -2,7 +2,6 @@ package binlog
 
 import (
 	"bytes"
-	"encoding/binary"
 )
 
 type Capabilities struct {
@@ -115,8 +114,9 @@ func (c *Conn) decodeHandshakePacket() error {
 	packet.AuthPluginDataPart2 = c.readBytes(packet.AuthPluginDataLength - 8)
 	packet.AuthPluginName = c.getString(TypeNullTerminatedString, 0)
 
-	if c.scanner.Err() != nil {
-		return c.scanner.Err()
+	err := c.scanner.Err()
+	if err != nil {
+		return err
 	}
 
 	c.Handshake = &packet
@@ -124,70 +124,50 @@ func (c *Conn) decodeHandshakePacket() error {
 	return nil
 }
 
-func (c *Conn) encodeHandshakeResponse() []byte {
-	hr := NewHandshakeResponse()
-	buf := bytes.NewBuffer(make([]byte, 0))
-
-	// Capabilities flag.
-	var cf capability = 0
-
-	// Write Capability Flags.
-	buf.Write([]byte(cf))
-
-	// Write MaxPacketSize
-	buf.Write(MaxPacketSize)
-
-	// Write CharacterSet
-	cs := make([]byte, 2)
-	binary.LittleEndian.PutUint16(cs, uint16(hr.CharacterSet))
-	buf.Write(cs[:1])
-
-	// Write Filler
-	buf.Write(make([]byte, 23))
-
-	// Write username
-	u := append([]byte(hr.Username), NullByte)
-	buf.Write(u)
+func (c *Conn) writeHandshakeResponse() error {
+	hr := c.NewHandshakeResponse()
+	cf := c.structToBitmask(hr.ClientFlag)
+	c.putBytes(cf)
+	c.putInt(TypeFixedInt, MaxPacketSize, 4)
+	c.putInt(TypeFixedInt, hr.CharacterSet, 1)
+	c.putNullBytes(23)
+	c.putString(TypeNullTerminatedString, hr.Username)
 
 	salt := append(c.Handshake.AuthPluginDataPart1.Bytes(), c.Handshake.AuthPluginDataPart2.Bytes()...)
 	ar := c.cachingSha2Auth(salt, []byte(hr.AuthResponse))
 	if hr.ClientFlag.PluginAuthLenEncClientData {
-		buf.Write(c.encLenEncInt(uint64(len(ar))))
-		buf.Write(ar)
+		c.putInt(TypeLenEncInt, uint64(len(ar)), 0)
+		c.putBytes(ar)
 	} else if hr.ClientFlag.SecureConnection {
-		l := make([]byte, 2)
-		binary.LittleEndian.PutUint16(l, uint16(len(ar)))
-		buf.Write(l[:1])
-		buf.Write(ar)
+		c.putInt(TypeFixedInt, uint64(len(ar)), 1)
+		c.putBytes(ar)
 	} else {
-		buf.Write(append(ar, NullByte))
+		c.putString(TypeNullTerminatedString, c.Config.Pass)
 	}
 
 	// Write database name
 	if hr.ClientFlag.ConnectWithDB {
-		buf.Write(append([]byte(hr.Database), NullByte))
+		c.putString(TypeNullTerminatedString, hr.Database)
 	}
 
 	// Write auth plugin
 	if hr.ClientFlag.PluginAuth {
-		buf.Write([]byte(hr.ClientPluginName))
+		c.putString(TypeNullTerminatedString, hr.ClientPluginName)
 	}
 
-	pl := make([]byte, 4)
-	binary.LittleEndian.PutUint32(pl, uint32(buf.Len()))
-	p := append(pl[:3], 1)
-	p = append(p, buf.Bytes()...)
-	buf = bytes.NewBuffer(p)
+	if c.Flush() != nil {
+		return c.Flush()
+	}
 
-	return buf.Bytes()
+	return nil
 }
 
-func NewHandshakeResponse() *HandshakeResponse {
+func (c *Conn) NewHandshakeResponse() *HandshakeResponse {
 	return &HandshakeResponse{
 		ClientFlag: &Capabilities{
 			LongPassword:               true,
 			FoundRows:                  true,
-			LongFlag:                   true,
+			LongFlag:                   false,
 			ConnectWithDB:              true,
 			NoSchema:                   false,
 			Compress:                   false,
@@ -200,27 +180,27 @@ func NewHandshakeResponse() *HandshakeResponse {
 			IgnoreSigpipe:              false,
 			Transactions:               true,
 			LegacyProtocol41:           false,
-			SecureConnection:           true,
+			SecureConnection:           false,
 			MultiStatements:            false,
 			MultiResults:               false,
 			PSMultiResults:             true,
 			PluginAuth:                 false,
 			ConnectAttrs:               false,
-			PluginAuthLenEncClientData: false,
+			PluginAuthLenEncClientData: true,
 			CanHandleExpiredPasswords:  false,
-			SessionTrack:               true,
-			DeprecateEOF:               true,
+			SessionTrack:               false,
+			DeprecateEOF:               false,
 			SSLVerifyServerCert:        false,
-			OptionalResultSetMetadata:  true,
-			RememberOptions:            true,
+			OptionalResultSetMetadata:  false,
+			RememberOptions:            false,
 		},
 		MaxPacketSize:      MaxPacketSize,
 		CharacterSet:       45,
-		Username:           "",
-		AuthResponseLength: 0,
-		AuthResponse:       "",
-		Database:           "",
-		ClientPluginName:   "",
+		Username:           c.Config.User,
+		AuthResponseLength: uint64(len(c.Config.Pass)),
+		AuthResponse:       c.Config.Pass,
+		Database:           c.Config.Database,
+		ClientPluginName:   c.Handshake.AuthPluginName,
 		KeyValues:          nil,
 	}
 }