Explorar o código

Implemented decoding of OK/EOF Packet

Josh Brickner %!s(int64=7) %!d(string=hai) anos
pai
achega
1bbc5e71af
Modificáronse 3 ficheiros con 83 adicións e 19 borrados
  1. 78 14
      binlog/connection.go
  2. 3 3
      binlog/handshake.go
  3. 2 2
      config.json

+ 78 - 14
binlog/connection.go

@@ -18,12 +18,17 @@ import (
 	"time"
 )
 
+// Misc. Constants
+const NullByte byte = 0
+const MaxPacketSize = MaxUint16
+
 // MySQL Packet Data Types
 const TypeNullTerminatedString = int(0)
 const TypeFixedString = int(1)
 const TypeFixedInt = int(2)
 const TypeLenEncInt = int(3)
 const TypeRestOfPacketString = int(4)
+const TypeLenEncString = int(5)
 
 // Integer Maximums
 const MaxUint08 = 1<<8 - 1
@@ -31,9 +36,11 @@ const MaxUint16 = 1<<16 - 1
 const MaxUint24 = 1<<24 - 1
 const MaxUint64 = 1<<64 - 1
 
-// Misc. Constants
-const NullByte byte = 0
-const MaxPacketSize = MaxUint16
+// Packet Statuses
+const StatusOK = 0x00
+const StatusEOF = 0xFE
+const StatusErr = 0xFF
+const StatusAuth = 0x01
 
 type Config struct {
 	Host       string `json:"host"`
@@ -75,6 +82,7 @@ type Conn struct {
 	err               error
 	sequenceId        uint64
 	writeBuf          *bytes.Buffer
+	StausFlags        *StatusFlags
 }
 
 func newBinlogConn(config *Config) Conn {
@@ -166,7 +174,7 @@ func (c *Conn) listen() error {
 	c.sequenceId++
 
 	switch ph.Status {
-	case 0x01:
+	case StatusAuth:
 		fmt.Println("IN: AuthMoreDate PACKET")
 		md, err := c.decodeAuthMoreDataResponsePacket(ph)
 		if err != nil {
@@ -186,11 +194,15 @@ func (c *Conn) listen() error {
 			}
 		}
 
-	case 0x00:
+	case StatusEOF:
+		fallthrough
+	case StatusOK:
 		fmt.Println("IN: OK PACKET")
-	case 0xFE:
-		fmt.Println("IN: EOF PACKET")
-	case 0xFF:
+		_, err := c.decodeOKPacket(ph)
+		if err != nil {
+			return err
+		}
+	case StatusErr:
 		fmt.Println("IN: ERROR PACKET")
 		ep, err := c.decodeErrorPacket(ph)
 		if err != nil {
@@ -303,6 +315,8 @@ func (c *Conn) getInt(t int, l uint64) uint64 {
 	switch t {
 	case TypeFixedInt:
 		v = c.decFixedInt(l)
+	case TypeLenEncInt:
+		v = c.decLenEncInt()
 	default:
 		v = 0
 	}
@@ -316,6 +330,8 @@ func (c *Conn) getString(t int, l uint64) string {
 	switch t {
 	case TypeFixedString:
 		v = c.decFixedString(l)
+	case TypeLenEncString:
+		v = string(c.decLenEncInt())
 	case TypeNullTerminatedString:
 		v = c.decNullTerminatedString()
 	case TypeRestOfPacketString:
@@ -342,6 +358,18 @@ func (c *Conn) decFixedString(l uint64) string {
 	return b.String()
 }
 
+func (c *Conn) decLenEncInt() uint64 {
+	var l uint16
+	b := c.readBytes(1)
+	br := bytes.NewReader(b.Bytes())
+	_ = binary.Read(br, binary.LittleEndian, &l)
+	if l > 0 {
+		return c.decFixedInt(uint64(l))
+	} else {
+		return 0
+	}
+}
+
 func (c *Conn) decFixedInt(l uint64) uint64 {
 	var i uint64
 	b := c.readBytes(l)
@@ -562,12 +590,12 @@ func (c *Conn) Flush() error {
 	_, _ = c.buffer.Write(c.writeBuf.Bytes())
 
 	// log all outgoing packets
-	fmt.Printf(
-		"\nOUT:\n%08b\n%x\n%s\n\n",
-		c.writeBuf.Bytes(),
-		c.writeBuf.Bytes(),
-		c.writeBuf.Bytes(),
-	)
+	// fmt.Printf(
+	// 	"\nOUT:\n%08b\n%x\n%s\n\n",
+	// 	c.writeBuf.Bytes(),
+	// 	c.writeBuf.Bytes(),
+	// 	c.writeBuf.Bytes(),
+	// )
 
 	if c.buffer.Flush() != nil {
 		return c.buffer.Flush()
@@ -595,6 +623,42 @@ func (c *Conn) setupWriteBuffer() {
 	}
 }
 
+type StatusFlags struct {
+}
+
+type OKPacket struct {
+	PacketHeader
+	Header           uint64
+	AffectedRows     uint64
+	LastInsertID     uint64
+	StatusFlags      uint64
+	Warnings         uint64
+	Info             string
+	SessionStateInfo string
+}
+
+func (c *Conn) decodeOKPacket(ph PacketHeader) (*OKPacket, error) {
+	op := OKPacket{}
+	op.PacketHeader = ph
+	op.Header = ph.Status
+	op.AffectedRows = c.getInt(TypeLenEncInt, 0)
+	op.LastInsertID = c.getInt(TypeLenEncInt, 0)
+	if c.HandshakeResponse.ClientFlag.Protocol41 {
+		op.StatusFlags = c.getInt(TypeFixedInt, 2)
+		op.Warnings = c.getInt(TypeFixedInt, 1)
+	} else if c.HandshakeResponse.ClientFlag.Transactions {
+		op.StatusFlags = c.getInt(TypeFixedInt, 2)
+	}
+
+	if c.HandshakeResponse.ClientFlag.SessionTrack {
+		op.Info = c.getString(TypeRestOfPacketString, 0)
+	} else {
+		op.Info = c.getString(TypeRestOfPacketString, 0)
+	}
+
+	return &op, nil
+}
+
 type ErrorPacket struct {
 	PacketHeader
 	ErrorCode      uint64

+ 3 - 3
binlog/handshake.go

@@ -216,17 +216,17 @@ func (c *Conn) NewHandshakeResponse() *HandshakeResponse {
 			Interactive:                true,
 			SSL:                        c.Config.SSL,
 			IgnoreSigpipe:              false,
-			Transactions:               true,
+			Transactions:               c.Handshake.Capabilities.Transactions,
 			LegacyProtocol41:           false,
 			SecureConnection:           true,
 			MultiStatements:            false,
 			MultiResults:               false,
 			PSMultiResults:             true,
-			PluginAuth:                 true,
+			PluginAuth:                 c.Handshake.Capabilities.PluginAuth,
 			ConnectAttrs:               false,
 			PluginAuthLenEncClientData: false,
 			CanHandleExpiredPasswords:  false,
-			SessionTrack:               true,
+			SessionTrack:               c.Handshake.Capabilities.SessionTrack,
 			DeprecateEOF:               false,
 			SSLVerifyServerCert:        c.Config.VerifyCert,
 			OptionalResultSetMetadata:  false,

+ 2 - 2
config.json

@@ -1,10 +1,10 @@
 {
   "host": "127.0.0.1",
-  "port": 3318,
+  "port": 3316,
   "user": "root",
   "password": "root",
   "database": "information_schema",
-  "ssl": true,
+  "ssl": false,
   "ssl-key": "",
   "ssl-cer": "",
   "ssl-ca": "/Users/josh/Sites/Certificates.pem",