|
@@ -3,12 +3,14 @@ package binlog
|
|
|
import (
|
|
import (
|
|
|
"bufio"
|
|
"bufio"
|
|
|
"bytes"
|
|
"bytes"
|
|
|
|
|
+ "crypto/tls"
|
|
|
"database/sql"
|
|
"database/sql"
|
|
|
"database/sql/driver"
|
|
"database/sql/driver"
|
|
|
"encoding/binary"
|
|
"encoding/binary"
|
|
|
"encoding/json"
|
|
"encoding/json"
|
|
|
"errors"
|
|
"errors"
|
|
|
"fmt"
|
|
"fmt"
|
|
|
|
|
+ "io"
|
|
|
"io/ioutil"
|
|
"io/ioutil"
|
|
|
"math"
|
|
"math"
|
|
|
"net"
|
|
"net"
|
|
@@ -25,7 +27,7 @@ const TypeLenEncInt = int(3)
|
|
|
const TypeRestOfPacketString = int(4)
|
|
const TypeRestOfPacketString = int(4)
|
|
|
|
|
|
|
|
// Integer Maximums
|
|
// Integer Maximums
|
|
|
-const MaxUint8 = 1<<8 - 1
|
|
|
|
|
|
|
+const MaxUint08 = 1<<8 - 1
|
|
|
const MaxUint16 = 1<<16 - 1
|
|
const MaxUint16 = 1<<16 - 1
|
|
|
const MaxUint24 = 1<<24 - 1
|
|
const MaxUint24 = 1<<24 - 1
|
|
|
const MaxUint64 = 1<<64 - 1
|
|
const MaxUint64 = 1<<64 - 1
|
|
@@ -41,13 +43,16 @@ type Config struct {
|
|
|
Pass string `json:"password"`
|
|
Pass string `json:"password"`
|
|
|
Database string `json:"database"`
|
|
Database string `json:"database"`
|
|
|
SSL bool `json:"ssl"`
|
|
SSL bool `json:"ssl"`
|
|
|
- VerifyCert bool `json:"verify_cert"`
|
|
|
|
|
|
|
+ SSLCA string `json:"ssl-ca"`
|
|
|
|
|
+ SSLCer string `json:"ssl-cer"`
|
|
|
|
|
+ SSLKey string `json:"ssl-key"`
|
|
|
|
|
+ VerifyCert bool `json:"verify-cert"`
|
|
|
Timeout time.Duration
|
|
Timeout time.Duration
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
func splitByBytesFunc(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
|
func splitByBytesFunc(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
|
|
if atEOF {
|
|
if atEOF {
|
|
|
- return 0, nil, errors.New("scanner found EOF")
|
|
|
|
|
|
|
+ return 0, nil, io.EOF
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
return 1, data[:1], nil
|
|
return 1, data[:1], nil
|
|
@@ -69,7 +74,9 @@ func newBinlogConfig(dsn string) (*Config, error) {
|
|
|
|
|
|
|
|
type Conn struct {
|
|
type Conn struct {
|
|
|
Config *Config
|
|
Config *Config
|
|
|
|
|
+ curConn net.Conn
|
|
|
tcpConn *net.TCPConn
|
|
tcpConn *net.TCPConn
|
|
|
|
|
+ secTCPConn *tls.Conn
|
|
|
Handshake *Handshake
|
|
Handshake *Handshake
|
|
|
HandshakeResponse *HandshakeResponse
|
|
HandshakeResponse *HandshakeResponse
|
|
|
buffer *bufio.ReadWriter
|
|
buffer *bufio.ReadWriter
|
|
@@ -108,9 +115,10 @@ func (d Driver) Open(dsn string) (driver.Conn, error) {
|
|
|
|
|
|
|
|
c := newBinlogConn(config)
|
|
c := newBinlogConn(config)
|
|
|
|
|
|
|
|
|
|
+ var t interface{}
|
|
|
dialer := net.Dialer{Timeout: c.Config.Timeout}
|
|
dialer := net.Dialer{Timeout: c.Config.Timeout}
|
|
|
addr := fmt.Sprintf("%s:%d", c.Config.Host, c.Config.Port)
|
|
addr := fmt.Sprintf("%s:%d", c.Config.Host, c.Config.Port)
|
|
|
- t, err := dialer.Dial("tcp", addr)
|
|
|
|
|
|
|
+ t, err = dialer.Dial("tcp", addr)
|
|
|
|
|
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
netErr, ok := err.(net.Error)
|
|
netErr, ok := err.(net.Error)
|
|
@@ -120,6 +128,7 @@ func (d Driver) Open(dsn string) (driver.Conn, error) {
|
|
|
}
|
|
}
|
|
|
} else {
|
|
} else {
|
|
|
c.tcpConn = t.(*net.TCPConn)
|
|
c.tcpConn = t.(*net.TCPConn)
|
|
|
|
|
+ c.setConnection(t.(net.Conn))
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
err = c.decodeHandshakePacket()
|
|
err = c.decodeHandshakePacket()
|
|
@@ -127,6 +136,27 @@ func (d Driver) Open(dsn string) (driver.Conn, error) {
|
|
|
return nil, err
|
|
return nil, err
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+ c.HandshakeResponse = c.NewHandshakeResponse()
|
|
|
|
|
+
|
|
|
|
|
+ // Send SSL_Request Packet
|
|
|
|
|
+ if c.Config.SSL {
|
|
|
|
|
+ err = c.writeSSLRequestPacket()
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return nil, err
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ tlsConf := NewClientTLSConfig(
|
|
|
|
|
+ c.Config.SSLKey,
|
|
|
|
|
+ c.Config.SSLCer,
|
|
|
|
|
+ []byte(c.Config.SSLCA),
|
|
|
|
|
+ c.Config.VerifyCert,
|
|
|
|
|
+ c.Config.Host,
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ c.secTCPConn = tls.Client(c.tcpConn, tlsConf)
|
|
|
|
|
+ c.setConnection(c.secTCPConn)
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
err = c.writeHandshakeResponse()
|
|
err = c.writeHandshakeResponse()
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
return nil, err
|
|
return nil, err
|
|
@@ -138,24 +168,36 @@ func (d Driver) Open(dsn string) (driver.Conn, error) {
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
func (c *Conn) listen() error {
|
|
func (c *Conn) listen() error {
|
|
|
|
|
+ fmt.Println("LISTEN")
|
|
|
ph, err := c.getPacketHeader()
|
|
ph, err := c.getPacketHeader()
|
|
|
|
|
+
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
return err
|
|
return err
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
switch ph.Status {
|
|
switch ph.Status {
|
|
|
case 0x01:
|
|
case 0x01:
|
|
|
- p, err := c.decodeAuthMoreDataResponsePacket(ph)
|
|
|
|
|
|
|
+ fmt.Println("IN: AuthMoreDate PACKET")
|
|
|
|
|
+ _, err := c.decodeAuthMoreDataResponsePacket(ph)
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
return err
|
|
return err
|
|
|
}
|
|
}
|
|
|
- fmt.Printf("%+v", p)
|
|
|
|
|
case 0x00:
|
|
case 0x00:
|
|
|
- fmt.Println("OK")
|
|
|
|
|
|
|
+ fmt.Println("IN: OK PACKET")
|
|
|
case 0xFE:
|
|
case 0xFE:
|
|
|
- fmt.Println("EOF")
|
|
|
|
|
|
|
+ fmt.Println("IN: EOF PACKET")
|
|
|
case 0xFF:
|
|
case 0xFF:
|
|
|
- fmt.Println("ERROR")
|
|
|
|
|
|
|
+ fmt.Println("IN: ERROR PACKET")
|
|
|
|
|
+ ep, err := c.decodeErrorPacket(ph)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return err
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ err = errors.New(fmt.Sprintf("Error %d: %s", ep.ErrorCode, ep.ErrorMessage))
|
|
|
|
|
+ return err
|
|
|
|
|
+ default:
|
|
|
|
|
+ fmt.Printf("ph = %+v\n", ph)
|
|
|
|
|
+ fmt.Println("IN: UNKNOWN PACKET")
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
err = c.scanner.Err()
|
|
err = c.scanner.Err()
|
|
@@ -196,16 +238,6 @@ func init() {
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
func (c *Conn) readBytes(l uint64) *bytes.Buffer {
|
|
func (c *Conn) readBytes(l uint64) *bytes.Buffer {
|
|
|
- if c.buffer == nil {
|
|
|
|
|
- c.buffer = bufio.NewReadWriter(
|
|
|
|
|
- bufio.NewReader(c.tcpConn),
|
|
|
|
|
- bufio.NewWriter(c.tcpConn),
|
|
|
|
|
- )
|
|
|
|
|
-
|
|
|
|
|
- c.scanner = bufio.NewScanner(c.buffer.Reader)
|
|
|
|
|
- c.scanner.Split(splitByBytesFunc)
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
b := make([]byte, 0)
|
|
b := make([]byte, 0)
|
|
|
for i := uint64(0); i < l; i++ {
|
|
for i := uint64(0); i < l; i++ {
|
|
|
c.scanner.Scan()
|
|
c.scanner.Scan()
|
|
@@ -215,6 +247,31 @@ func (c *Conn) readBytes(l uint64) *bytes.Buffer {
|
|
|
return bytes.NewBuffer(b)
|
|
return bytes.NewBuffer(b)
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+func (c *Conn) getBytesUntilEOF() *bytes.Buffer {
|
|
|
|
|
+ l := uint64(1)
|
|
|
|
|
+ s := c.readBytes(l)
|
|
|
|
|
+ b := s.Bytes()
|
|
|
|
|
+
|
|
|
|
|
+ for true {
|
|
|
|
|
+ if uint64(s.Len()) != l || s.Bytes()[0] == NullByte {
|
|
|
|
|
+ break
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ s = c.readBytes(uint64(l))
|
|
|
|
|
+
|
|
|
|
|
+ err := c.scanner.Err()
|
|
|
|
|
+ if err == io.EOF {
|
|
|
|
|
+ return bytes.NewBuffer(b)
|
|
|
|
|
+ } else if err != nil {
|
|
|
|
|
+ panic(err)
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ b = append(b, s.Bytes()...)
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ return bytes.NewBuffer(b)
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
func (c *Conn) getBytesUntilNull() *bytes.Buffer {
|
|
func (c *Conn) getBytesUntilNull() *bytes.Buffer {
|
|
|
l := uint64(1)
|
|
l := uint64(1)
|
|
|
s := c.readBytes(l)
|
|
s := c.readBytes(l)
|
|
@@ -259,6 +316,8 @@ func (c *Conn) getString(t int, l uint64) string {
|
|
|
v = c.decFixedString(l)
|
|
v = c.decFixedString(l)
|
|
|
case TypeNullTerminatedString:
|
|
case TypeNullTerminatedString:
|
|
|
v = c.decNullTerminatedString()
|
|
v = c.decNullTerminatedString()
|
|
|
|
|
+ case TypeRestOfPacketString:
|
|
|
|
|
+ v = c.decRestOfPacketString()
|
|
|
default:
|
|
default:
|
|
|
v = ""
|
|
v = ""
|
|
|
}
|
|
}
|
|
@@ -266,6 +325,11 @@ func (c *Conn) getString(t int, l uint64) string {
|
|
|
return v
|
|
return v
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+func (c *Conn) decRestOfPacketString() string {
|
|
|
|
|
+ b := c.getBytesUntilEOF()
|
|
|
|
|
+ return string(b.Bytes())
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
func (c *Conn) decNullTerminatedString() string {
|
|
func (c *Conn) decNullTerminatedString() string {
|
|
|
b := c.getBytesUntilNull()
|
|
b := c.getBytesUntilNull()
|
|
|
return strings.TrimRight(b.String(), string(NullByte))
|
|
return strings.TrimRight(b.String(), string(NullByte))
|
|
@@ -322,11 +386,11 @@ func (c *Conn) encLenEncInt(v uint64) []byte {
|
|
|
prefix := make([]byte, 1)
|
|
prefix := make([]byte, 1)
|
|
|
var b []byte
|
|
var b []byte
|
|
|
switch {
|
|
switch {
|
|
|
- case v < MaxUint8:
|
|
|
|
|
|
|
+ case v < MaxUint08:
|
|
|
b = make([]byte, 2)
|
|
b = make([]byte, 2)
|
|
|
binary.LittleEndian.PutUint16(b, uint16(v))
|
|
binary.LittleEndian.PutUint16(b, uint16(v))
|
|
|
b = b[:1]
|
|
b = b[:1]
|
|
|
- case v >= MaxUint8 && v < MaxUint16:
|
|
|
|
|
|
|
+ case v >= MaxUint08 && v < MaxUint16:
|
|
|
prefix[0] = 0xFC
|
|
prefix[0] = 0xFC
|
|
|
b = make([]byte, 3)
|
|
b = make([]byte, 3)
|
|
|
binary.LittleEndian.PutUint16(b, uint16(v))
|
|
binary.LittleEndian.PutUint16(b, uint16(v))
|
|
@@ -494,10 +558,21 @@ func (c *Conn) Flush() error {
|
|
|
|
|
|
|
|
c.writeBuf = c.addHeader()
|
|
c.writeBuf = c.addHeader()
|
|
|
_, _ = c.buffer.Write(c.writeBuf.Bytes())
|
|
_, _ = c.buffer.Write(c.writeBuf.Bytes())
|
|
|
|
|
+
|
|
|
|
|
+ // log all packets
|
|
|
|
|
+ fmt.Printf(
|
|
|
|
|
+ "\nOUT:\n%08b\n%x\n%s\n\n",
|
|
|
|
|
+ c.writeBuf.Bytes(),
|
|
|
|
|
+ c.writeBuf.Bytes(),
|
|
|
|
|
+ c.writeBuf.Bytes(),
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
if c.buffer.Flush() != nil {
|
|
if c.buffer.Flush() != nil {
|
|
|
return c.buffer.Flush()
|
|
return c.buffer.Flush()
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+ c.writeBuf = nil
|
|
|
|
|
+
|
|
|
return nil
|
|
return nil
|
|
|
}
|
|
}
|
|
|
|
|
|
|
@@ -517,3 +592,39 @@ func (c *Conn) setupWriteBuffer() {
|
|
|
c.writeBuf = bytes.NewBuffer(nil)
|
|
c.writeBuf = bytes.NewBuffer(nil)
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
+
|
|
|
|
|
+type ErrorPacket struct {
|
|
|
|
|
+ PacketHeader
|
|
|
|
|
+ ErrorCode uint64
|
|
|
|
|
+ ErrorMessage string
|
|
|
|
|
+ SQLStateMarker string
|
|
|
|
|
+ SQLState string
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+func (c *Conn) decodeErrorPacket(ph PacketHeader) (*ErrorPacket, error) {
|
|
|
|
|
+ ep := ErrorPacket{}
|
|
|
|
|
+ ep.PacketHeader = ph
|
|
|
|
|
+ ep.ErrorCode = c.getInt(TypeFixedInt, 2)
|
|
|
|
|
+ ep.SQLStateMarker = c.getString(TypeFixedString, 1)
|
|
|
|
|
+ ep.SQLState = c.getString(TypeFixedString, 5)
|
|
|
|
|
+ ep.ErrorMessage = c.getString(TypeRestOfPacketString, 0)
|
|
|
|
|
+
|
|
|
|
|
+ err := c.scanner.Err()
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return nil, err
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ return &ep, nil
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+func (c *Conn) setConnection(nc net.Conn) {
|
|
|
|
|
+ c.curConn = nc
|
|
|
|
|
+
|
|
|
|
|
+ c.buffer = bufio.NewReadWriter(
|
|
|
|
|
+ bufio.NewReader(c.curConn),
|
|
|
|
|
+ bufio.NewWriter(c.curConn),
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ c.scanner = bufio.NewScanner(c.buffer.Reader)
|
|
|
|
|
+ c.scanner.Split(splitByBytesFunc)
|
|
|
|
|
+}
|