|
|
@@ -10,7 +10,6 @@ import (
|
|
|
"encoding/json"
|
|
|
"errors"
|
|
|
"fmt"
|
|
|
- "io"
|
|
|
"io/ioutil"
|
|
|
"math"
|
|
|
"net"
|
|
|
@@ -50,14 +49,6 @@ type Config struct {
|
|
|
Timeout time.Duration
|
|
|
}
|
|
|
|
|
|
-func splitByBytesFunc(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
|
|
- if atEOF {
|
|
|
- return 0, nil, io.EOF
|
|
|
- }
|
|
|
-
|
|
|
- return 1, data[:1], nil
|
|
|
-}
|
|
|
-
|
|
|
func newBinlogConfig(dsn string) (*Config, error) {
|
|
|
var err error
|
|
|
|
|
|
@@ -138,7 +129,7 @@ func (d Driver) Open(dsn string) (driver.Conn, error) {
|
|
|
|
|
|
c.HandshakeResponse = c.NewHandshakeResponse()
|
|
|
|
|
|
- // Send SSL_Request Packet
|
|
|
+ // If we are on SSL send SSL_Request packet now
|
|
|
if c.Config.SSL {
|
|
|
err = c.writeSSLRequestPacket()
|
|
|
if err != nil {
|
|
|
@@ -168,9 +159,7 @@ func (d Driver) Open(dsn string) (driver.Conn, error) {
|
|
|
}
|
|
|
|
|
|
func (c *Conn) listen() error {
|
|
|
- fmt.Println("LISTEN")
|
|
|
ph, err := c.getPacketHeader()
|
|
|
-
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
|
@@ -178,10 +167,24 @@ func (c *Conn) listen() error {
|
|
|
switch ph.Status {
|
|
|
case 0x01:
|
|
|
fmt.Println("IN: AuthMoreDate PACKET")
|
|
|
- _, err := c.decodeAuthMoreDataResponsePacket(ph)
|
|
|
+ md, err := c.decodeAuthMoreDataResponsePacket(ph)
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
|
+
|
|
|
+ switch md.Data {
|
|
|
+ case SHA2_FAST_AUTH_SUCCESS:
|
|
|
+ fmt.Println("FAST AUTH")
|
|
|
+ case SHA2_REQUEST_PUBLIC_KEY:
|
|
|
+ fmt.Println("REQUEST PUBLIC KEY")
|
|
|
+ case SHA2_PERFORM_FULL_AUTHENTICATION:
|
|
|
+ fmt.Println("FULL AUTH")
|
|
|
+ c.putBytes(append([]byte(c.Config.Pass), NullByte))
|
|
|
+ if c.Flush() != nil {
|
|
|
+ return c.Flush()
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
case 0x00:
|
|
|
fmt.Println("IN: OK PACKET")
|
|
|
case 0xFE:
|
|
|
@@ -195,9 +198,6 @@ func (c *Conn) listen() error {
|
|
|
|
|
|
err = errors.New(fmt.Sprintf("Error %d: %s", ep.ErrorCode, ep.ErrorMessage))
|
|
|
return err
|
|
|
- default:
|
|
|
- fmt.Printf("ph = %+v\n", ph)
|
|
|
- fmt.Println("IN: UNKNOWN PACKET")
|
|
|
}
|
|
|
|
|
|
err = c.scanner.Err()
|
|
|
@@ -240,7 +240,11 @@ func init() {
|
|
|
func (c *Conn) readBytes(l uint64) *bytes.Buffer {
|
|
|
b := make([]byte, 0)
|
|
|
for i := uint64(0); i < l; i++ {
|
|
|
- c.scanner.Scan()
|
|
|
+ didScan := c.scanner.Scan()
|
|
|
+ if !didScan {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+
|
|
|
b = append(b, c.scanner.Bytes()...)
|
|
|
}
|
|
|
|
|
|
@@ -257,13 +261,9 @@ func (c *Conn) getBytesUntilEOF() *bytes.Buffer {
|
|
|
break
|
|
|
}
|
|
|
|
|
|
- s = c.readBytes(uint64(l))
|
|
|
-
|
|
|
- err := c.scanner.Err()
|
|
|
- if err == io.EOF {
|
|
|
+ s := c.readBytes(uint64(l))
|
|
|
+ if s == nil {
|
|
|
return bytes.NewBuffer(b)
|
|
|
- } else if err != nil {
|
|
|
- panic(err)
|
|
|
}
|
|
|
|
|
|
b = append(b, s.Bytes()...)
|
|
|
@@ -556,10 +556,11 @@ func (c *Conn) Flush() error {
|
|
|
return c.err
|
|
|
}
|
|
|
|
|
|
+ fmt.Println(string(c.writeBuf.Bytes()))
|
|
|
c.writeBuf = c.addHeader()
|
|
|
_, _ = c.buffer.Write(c.writeBuf.Bytes())
|
|
|
|
|
|
- // log all packets
|
|
|
+ // log all outgoing packets
|
|
|
fmt.Printf(
|
|
|
"\nOUT:\n%08b\n%x\n%s\n\n",
|
|
|
c.writeBuf.Bytes(),
|
|
|
@@ -626,5 +627,5 @@ func (c *Conn) setConnection(nc net.Conn) {
|
|
|
)
|
|
|
|
|
|
c.scanner = bufio.NewScanner(c.buffer.Reader)
|
|
|
- c.scanner.Split(splitByBytesFunc)
|
|
|
+ c.scanner.Split(bufio.ScanBytes)
|
|
|
}
|