Bladeren bron

Code cleanup

Josh Brickner 7 jaren geleden
bovenliggende
commit
a3a3c3cac9
3 gewijzigde bestanden met toevoegingen van 148 en 99 verwijderingen
  1. 1 0
      .gitignore
  2. 42 5
      binlog/connection.go
  3. 105 94
      binlog/handshake.go

+ 1 - 0
.gitignore

@@ -1,3 +1,4 @@
 tags
 .idea
 .DS_Store
+mysql_binlog_filter

+ 42 - 5
binlog/connection.go

@@ -20,6 +20,11 @@ const TypeFixedString = int(1)
 const TypeFixedInt = int(2)
 const TypeLenEncodedInt = int(3)
 
+const MaxUint8 = 1<<8 - 1
+const MaxUint16 = 1<<16 - 1
+const MaxUint24 = 1<<24 - 1
+const MaxUint64 = 1<<64 - 1
+
 type Config struct {
 	Host       string `json:"host"`
 	Port       int    `json:"port"`
@@ -87,12 +92,10 @@ func (d Driver) Open(dsn string) (driver.Conn, error) {
 		}
 	}
 
-	hsp, err := blConn.handshakePacket()
-	blConn.Handshake = hsp
+	err = blConn.decodeHandshakePacket()
+	b := blConn.encodeHandshakeResponse()
+	fmt.Printf("%08b\n%d\n%s", b, b, b)
 
-	resp := blConn.handshakeResponse()
-	b := resp.encode(&blConn)
-	fmt.Printf("%d", b)
 	_, err = blConn.tcpConn.Write(b)
 
 	return blConn, err
@@ -213,3 +216,37 @@ func (c *Conn) popFixedInt(l uint64) (uint64, error) {
 
 	return i, err
 }
+
+func (c *Conn) encFixedLenInt(l uint64, v uint64) []byte {
+	b := make([]byte, 4)
+	binary.LittleEndian.PutUint64(b, v)
+	return b[:(l - 1)]
+}
+
+func (c *Conn) encLenEncInt(v uint64) []byte {
+	prefix := make([]byte, 1)
+	var b []byte
+	switch {
+	case v < MaxUint8:
+		b = make([]byte, 2)
+		binary.LittleEndian.PutUint16(b, uint16(v))
+		b = b[:1]
+	case v >= MaxUint8 && v < MaxUint16:
+		prefix[0] = 0xFC
+		b = make([]byte, 3)
+		binary.LittleEndian.PutUint16(b, uint16(v))
+		b = b[:2]
+	case v >= MaxUint16 && v < MaxUint24:
+		prefix[0] = 0xFD
+		b = make([]byte, 4)
+		binary.LittleEndian.PutUint32(b, uint32(v))
+		b = b[:3]
+	case v >= MaxUint24 && v < MaxUint64:
+		prefix[0] = 0xFE
+		b = make([]byte, 9)
+		binary.LittleEndian.PutUint64(b, uint64(v))
+	}
+
+	b = append(prefix, b...)
+	return b
+}

+ 105 - 94
binlog/handshake.go

@@ -5,24 +5,6 @@ import (
 	"encoding/binary"
 )
 
-type HandshakePacket struct {
-	PacketLength         uint64
-	SequenceId           uint64
-	ProtocolVersion      uint64
-	ServerVersion        string
-	ThreadId             uint64
-	AuthPluginDataPart1  []byte
-	CapabilityFlags1     []byte
-	Charset              uint64
-	Status               []byte
-	CapabilityFlags2     []byte
-	AuthPluginDataLength uint64
-	AuthPluginDataPart2  []byte
-	AuthPluginName       string
-	CapabilityFlags      *CapabilityFlags
-	StatusFlags          *StatusFlags
-}
-
 type CapabilityFlags struct {
 	LongPassword               bool
 	FoundRows                  bool
@@ -30,26 +12,26 @@ type CapabilityFlags struct {
 	ConnectWithDb              bool
 	NoSchema                   bool
 	Compress                   bool
-	Odbc                       bool
+	ODBC                       bool
 	LocalFiles                 bool
 	IgnoreSpace                bool
 	Protocol41                 bool
 	Interactive                bool
-	Ssl                        bool
+	SSL                        bool
 	IgnoreSigpipe              bool
 	Transactions               bool
-	Reserved                   bool
-	Reserved2                  bool
+	LegacyProtocol41           bool
+	SecureConnection           bool
 	MultiStatements            bool
 	MultiResults               bool
-	PsMultiResults             bool
+	PSMultiResults             bool
 	PluginAuth                 bool
 	ConnectAttrs               bool
 	PluginAuthLenEncClientData bool
 	CanHandleExpiredPasswords  bool
 	SessionTrack               bool
 	DeprecateEOF               bool
-	SslVerifyServerCert        bool
+	SSLVerifyServerCert        bool
 	OptionalResultSetMetadata  bool
 	RememberOptions            bool
 }
@@ -62,16 +44,46 @@ type StatusFlags struct {
 	QueryNoIndexUsed         bool
 	StatusCursorExists       bool
 	StatusLastRowSent        bool
-	StatusDbDropped          bool
+	StatusDBDropped          bool
 	StatusNoBackslashEscapes bool
 	StatusMetadataChanged    bool
 	QueryWasSlow             bool
-	PsOutParams              bool
+	PSOutParams              bool
 	StatusInTransReadonly    bool
 	SessionStateChanged      bool
 }
 
-func (hs *HandshakePacket) decodeCapabilityFlags() {
+type HandshakePacket struct {
+	PacketLength         uint64
+	SequenceID           uint64
+	ProtocolVersion      uint64
+	ServerVersion        string
+	ThreadID             uint64
+	AuthPluginDataPart1  []byte
+	CapabilityFlags1     []byte
+	Charset              uint64
+	Status               []byte
+	CapabilityFlags2     []byte
+	AuthPluginDataLength uint64
+	AuthPluginDataPart2  []byte
+	AuthPluginName       string
+	CapabilityFlags      *CapabilityFlags
+	StatusFlags          *StatusFlags
+}
+
+type HandshakeResponse struct {
+	ClientFlag         *CapabilityFlags
+	MaxPacketSize      uint64
+	CharacterSet       uint64
+	Username           string
+	AuthResponseLength uint64
+	AuthResponse       string
+	Database           string
+	ClientPluginName   string
+	KeyValues          map[string]string
+}
+
+func (c *Conn) decodeCapabilityFlags(hs *HandshakePacket) {
 	hs.CapabilityFlags = &CapabilityFlags{
 		LongPassword:               (hs.CapabilityFlags1[0] & 1) > 0,
 		FoundRows:                  (hs.CapabilityFlags1[0] & 2) > 0,
@@ -79,32 +91,32 @@ func (hs *HandshakePacket) decodeCapabilityFlags() {
 		ConnectWithDb:              (hs.CapabilityFlags1[0] & 8) > 0,
 		NoSchema:                   (hs.CapabilityFlags1[0] & 16) > 0,
 		Compress:                   (hs.CapabilityFlags1[0] & 32) > 0,
-		Odbc:                       (hs.CapabilityFlags1[0] & 64) > 0,
+		ODBC:                       (hs.CapabilityFlags1[0] & 64) > 0,
 		LocalFiles:                 (hs.CapabilityFlags1[0] & 128) > 0,
 		IgnoreSpace:                (hs.CapabilityFlags1[1] & 1) > 0,
 		Protocol41:                 (hs.CapabilityFlags1[1] & 2) > 0,
 		Interactive:                (hs.CapabilityFlags1[1] & 4) > 0,
-		Ssl:                        (hs.CapabilityFlags1[1] & 8) > 0,
+		SSL:                        (hs.CapabilityFlags1[1] & 8) > 0,
 		IgnoreSigpipe:              (hs.CapabilityFlags1[1] & 16) > 0,
 		Transactions:               (hs.CapabilityFlags1[1] & 32) > 0,
-		Reserved:                   (hs.CapabilityFlags1[1] & 64) > 0,
-		Reserved2:                  (hs.CapabilityFlags1[1] & 128) > 0,
+		LegacyProtocol41:           (hs.CapabilityFlags1[1] & 64) > 0,
+		SecureConnection:           (hs.CapabilityFlags1[1] & 128) > 0,
 		MultiStatements:            (hs.CapabilityFlags2[0] & 1) > 0,
 		MultiResults:               (hs.CapabilityFlags2[0] & 2) > 0,
-		PsMultiResults:             (hs.CapabilityFlags2[0] & 4) > 0,
+		PSMultiResults:             (hs.CapabilityFlags2[0] & 4) > 0,
 		PluginAuth:                 (hs.CapabilityFlags2[0] & 8) > 0,
 		ConnectAttrs:               (hs.CapabilityFlags2[0] & 16) > 0,
 		PluginAuthLenEncClientData: (hs.CapabilityFlags2[0] & 32) > 0,
 		CanHandleExpiredPasswords:  (hs.CapabilityFlags2[0] & 64) > 0,
 		SessionTrack:               (hs.CapabilityFlags2[0] & 128) > 0,
 		DeprecateEOF:               (hs.CapabilityFlags2[1] & 1) > 0,
-		SslVerifyServerCert:        (hs.CapabilityFlags2[1] & 2) > 0,
+		SSLVerifyServerCert:        (hs.CapabilityFlags2[1] & 2) > 0,
 		OptionalResultSetMetadata:  (hs.CapabilityFlags2[1] & 4) > 0,
 		RememberOptions:            (hs.CapabilityFlags2[1] & 8) > 0,
 	}
 }
 
-func (hs *HandshakePacket) decodeStatusFlags() {
+func (c *Conn) decodeStatusFlags(hs *HandshakePacket) {
 	hs.StatusFlags = &StatusFlags{
 		StatusInTrans:            (hs.Status[0] & 1) > 0,
 		StatusAutocommit:         (hs.Status[0] & 2) > 0,
@@ -113,115 +125,106 @@ func (hs *HandshakePacket) decodeStatusFlags() {
 		QueryNoIndexUsed:         (hs.Status[0] & 16) > 0,
 		StatusCursorExists:       (hs.Status[0] & 32) > 0,
 		StatusLastRowSent:        (hs.Status[0] & 64) > 0,
-		StatusDbDropped:          (hs.Status[0] & 128) > 0,
+		StatusDBDropped:          (hs.Status[0] & 128) > 0,
 		StatusNoBackslashEscapes: (hs.Status[1] & 1) > 0,
 		StatusMetadataChanged:    (hs.Status[1] & 2) > 0,
 		QueryWasSlow:             (hs.Status[1] & 4) > 0,
-		PsOutParams:              (hs.Status[1] & 8) > 0,
+		PSOutParams:              (hs.Status[1] & 8) > 0,
 		StatusInTransReadonly:    (hs.Status[1] & 16) > 0,
 		SessionStateChanged:      (hs.Status[1] & 32) > 0,
 	}
 }
 
-func (c *Conn) handshakePacket() (*HandshakePacket, error) {
+func (c *Conn) decodeHandshakePacket() error {
 	packet := HandshakePacket{}
 	var err error
 
 	packet.PacketLength, err = c.getInt(TypeFixedInt, 3)
 	if err != nil {
-		return nil, err
+		return err
 	}
 
-	packet.SequenceId, err = c.getInt(TypeFixedInt, 1)
+	packet.SequenceID, err = c.getInt(TypeFixedInt, 1)
 	if err != nil {
-		return nil, err
+		return err
 	}
 
 	packet.ProtocolVersion, err = c.getInt(TypeFixedInt, 1)
 	if err != nil {
-		return nil, err
+		return err
 	}
 
 	packet.ServerVersion, err = c.getString(TypeNullTerminatedString, 0)
 	if err != nil {
-		return nil, err
+		return err
 	}
 
-	packet.ThreadId, err = c.getInt(TypeFixedInt, 4)
+	packet.ThreadID, err = c.getInt(TypeFixedInt, 4)
 	if err != nil {
-		return nil, err
+		return err
 	}
 
 	packet.AuthPluginDataPart1, err = c.getBytes(8)
 	if err != nil {
-		return nil, err
+		return err
 	}
 
 	err = c.consumeBytes(1)
 	if err != nil {
-		return nil, err
+		return err
 	}
 
 	packet.CapabilityFlags1, err = c.getBytes(2)
 	if err != nil {
-		return nil, err
+		return err
 	}
 
 	packet.Charset, err = c.getInt(TypeFixedInt, 1)
 	if err != nil {
-		return nil, err
+		return err
 	}
 
 	packet.Status, err = c.getBytes(2)
 	if err != nil {
-		return nil, err
+		return err
 	}
 
-	packet.decodeStatusFlags()
+	c.decodeStatusFlags(&packet)
 
 	packet.CapabilityFlags2, err = c.getBytes(2)
 	if err != nil {
-		return nil, err
+		return err
 	}
 
-	packet.decodeCapabilityFlags()
+	c.decodeCapabilityFlags(&packet)
 
 	packet.AuthPluginDataLength, err = c.getInt(TypeFixedInt, 1)
 	if err != nil {
-		return nil, err
+		return err
 	}
 
 	err = c.consumeBytes(10)
 	if err != nil {
-		return nil, err
+		return err
 	}
 
 	packet.AuthPluginDataPart2, err = c.getBytes(packet.AuthPluginDataLength - 8)
 	if err != nil {
-		return nil, err
+		return err
 	}
 
 	packet.AuthPluginName, err = c.getString(TypeNullTerminatedString, 0)
 	if err != nil {
-		return nil, err
+		return err
 	}
 
-	return &packet, nil
-}
+	c.Handshake = &packet
 
-type HandshakeResponse struct {
-	ClientFlag         *CapabilityFlags
-	MaxPacketSize      uint64
-	CharacterSet       uint64
-	Username           string
-	AuthResponseLength uint64
-	AuthResponse       string
-	Database           string
-	ClientPluginName   string
-	KeyValues          map[string]string
+	return nil
 }
 
-func (hr *HandshakeResponse) encode(c *Conn) []byte {
+func (c *Conn) encodeHandshakeResponse() []byte {
+	hr := NewHandshakeResponse()
 	buf := bytes.NewBuffer(make([]byte, 0))
 
 	// Capabilities flag.
@@ -250,7 +253,7 @@ func (hr *HandshakeResponse) encode(c *Conn) []byte {
 		flags[0] |= 0x20
 	}
 
-	if hr.ClientFlag.Odbc {
+	if hr.ClientFlag.ODBC {
 		flags[0] |= 0x40
 	}
 
@@ -270,7 +273,7 @@ func (hr *HandshakeResponse) encode(c *Conn) []byte {
 		flags[1] |= 0x4
 	}
 
-	if hr.ClientFlag.Ssl {
+	if hr.ClientFlag.SSL {
 		flags[1] |= 0x8
 	}
 
@@ -282,11 +285,11 @@ func (hr *HandshakeResponse) encode(c *Conn) []byte {
 		flags[1] |= 0x20
 	}
 
-	if hr.ClientFlag.Reserved {
+	if hr.ClientFlag.LegacyProtocol41 {
 		flags[1] |= 0x40
 	}
 
-	if hr.ClientFlag.Reserved2 {
+	if hr.ClientFlag.SecureConnection {
 		flags[1] |= 0x80
 	}
 
@@ -298,7 +301,7 @@ func (hr *HandshakeResponse) encode(c *Conn) []byte {
 		flags[2] |= 0x2
 	}
 
-	if hr.ClientFlag.PsMultiResults {
+	if hr.ClientFlag.PSMultiResults {
 		flags[2] |= 0x4
 	}
 
@@ -326,7 +329,7 @@ func (hr *HandshakeResponse) encode(c *Conn) []byte {
 		flags[3] |= 0x1
 	}
 
-	if hr.ClientFlag.SslVerifyServerCert {
+	if hr.ClientFlag.SSLVerifyServerCert {
 		flags[3] |= 0x2
 	}
 
@@ -342,9 +345,7 @@ func (hr *HandshakeResponse) encode(c *Conn) []byte {
 	buf.Write(flags)
 
 	// Write MaxPacketSize
-	mps := make([]byte, 4)
-	binary.LittleEndian.PutUint32(mps, uint32(MaxPacketSize))
-	buf.Write(mps)
+	//buf.Write()
 
 	// Write CharacterSet
 	cs := make([]byte, 2)
@@ -358,8 +359,18 @@ func (hr *HandshakeResponse) encode(c *Conn) []byte {
 	u := append([]byte(hr.Username), NullByte)
 	buf.Write(u)
 
+	salt := append(c.Handshake.AuthPluginDataPart1, c.Handshake.AuthPluginDataPart2...)
+	ar := c.cachingSha2Auth(salt, []byte(hr.AuthResponse))
 	if hr.ClientFlag.PluginAuthLenEncClientData {
-
+		buf.Write(c.encLenEncInt(uint64(len(ar))))
+		buf.Write(ar)
+	} else if hr.ClientFlag.SecureConnection {
+		l := make([]byte, 2)
+		binary.LittleEndian.PutUint16(l, uint16(len(ar)))
+		buf.Write(l[:1])
+		buf.Write(ar)
+	} else {
+		buf.Write(append(ar, NullByte))
 	}
 
 	// Write database name
@@ -381,7 +392,7 @@ func (hr *HandshakeResponse) encode(c *Conn) []byte {
 	return buf.Bytes()
 }
 
-func (c *Conn) handshakeResponse() *HandshakeResponse {
+func NewHandshakeResponse() *HandshakeResponse {
 	return &HandshakeResponse{
 		ClientFlag: &CapabilityFlags{
 			LongPassword:               true,
@@ -389,37 +400,37 @@ func (c *Conn) handshakeResponse() *HandshakeResponse {
 			LongFlag:                   true,
 			ConnectWithDb:              true,
 			NoSchema:                   false,
-			Compress:                   true,
-			Odbc:                       false,
+			Compress:                   false,
+			ODBC:                       false,
 			LocalFiles:                 false,
 			IgnoreSpace:                true,
 			Protocol41:                 true,
 			Interactive:                true,
-			Ssl:                        c.Config.SSL,
+			SSL:                        false,
 			IgnoreSigpipe:              false,
 			Transactions:               true,
-			Reserved:                   false,
-			Reserved2:                  false,
+			LegacyProtocol41:           false,
+			SecureConnection:           true,
 			MultiStatements:            false,
 			MultiResults:               false,
-			PsMultiResults:             true,
-			PluginAuth:                 true,
+			PSMultiResults:             true,
+			PluginAuth:                 false,
 			ConnectAttrs:               false,
-			PluginAuthLenEncClientData: true,
+			PluginAuthLenEncClientData: false,
 			CanHandleExpiredPasswords:  false,
 			SessionTrack:               true,
 			DeprecateEOF:               true,
-			SslVerifyServerCert:        c.Config.VerifyCert,
+			SSLVerifyServerCert:        false,
 			OptionalResultSetMetadata:  true,
 			RememberOptions:            true,
 		},
 		MaxPacketSize:      MaxPacketSize,
 		CharacterSet:       45,
-		Username:           c.Config.User,
-		AuthResponseLength: 4,
-		AuthResponse:       "root",
-		Database:           c.Config.Database,
-		ClientPluginName:   c.Handshake.AuthPluginName,
+		Username:           "",
+		AuthResponseLength: 0,
+		AuthResponse:       "",
+		Database:           "",
+		ClientPluginName:   "",
 		KeyValues:          nil,
 	}
 }