Bläddra i källkod

Implemented TLS motherfuckers!

Josh Brickner 7 år sedan
förälder
incheckning
ea4bd702bc
4 ändrade filer med 209 tillägg och 29 borttagningar
  1. 132 21
      binlog/connection.go
  2. 71 7
      binlog/handshake.go
  3. 5 1
      config.json
  4. 1 0
      docker-compose.yaml

+ 132 - 21
binlog/connection.go

@@ -3,12 +3,14 @@ package binlog
 import (
 	"bufio"
 	"bytes"
+	"crypto/tls"
 	"database/sql"
 	"database/sql/driver"
 	"encoding/binary"
 	"encoding/json"
 	"errors"
 	"fmt"
+	"io"
 	"io/ioutil"
 	"math"
 	"net"
@@ -25,7 +27,7 @@ const TypeLenEncInt = int(3)
 const TypeRestOfPacketString = int(4)
 
 // Integer Maximums
-const MaxUint8 = 1<<8 - 1
+const MaxUint08 = 1<<8 - 1
 const MaxUint16 = 1<<16 - 1
 const MaxUint24 = 1<<24 - 1
 const MaxUint64 = 1<<64 - 1
@@ -41,13 +43,16 @@ type Config struct {
 	Pass       string `json:"password"`
 	Database   string `json:"database"`
 	SSL        bool   `json:"ssl"`
-	VerifyCert bool   `json:"verify_cert"`
+	SSLCA      string `json:"ssl-ca"`
+	SSLCer     string `json:"ssl-cer"`
+	SSLKey     string `json:"ssl-key"`
+	VerifyCert bool   `json:"verify-cert"`
 	Timeout    time.Duration
 }
 
 func splitByBytesFunc(data []byte, atEOF bool) (advance int, token []byte, err error) {
 	if atEOF {
-		return 0, nil, errors.New("scanner found EOF")
+		return 0, nil, io.EOF
 	}
 
 	return 1, data[:1], nil
@@ -69,7 +74,9 @@ func newBinlogConfig(dsn string) (*Config, error) {
 
 type Conn struct {
 	Config            *Config
+	curConn           net.Conn
 	tcpConn           *net.TCPConn
+	secTCPConn        *tls.Conn
 	Handshake         *Handshake
 	HandshakeResponse *HandshakeResponse
 	buffer            *bufio.ReadWriter
@@ -108,9 +115,10 @@ func (d Driver) Open(dsn string) (driver.Conn, error) {
 
 	c := newBinlogConn(config)
 
+	var t interface{}
 	dialer := net.Dialer{Timeout: c.Config.Timeout}
 	addr := fmt.Sprintf("%s:%d", c.Config.Host, c.Config.Port)
-	t, err := dialer.Dial("tcp", addr)
+	t, err = dialer.Dial("tcp", addr)
 
 	if err != nil {
 		netErr, ok := err.(net.Error)
@@ -120,6 +128,7 @@ func (d Driver) Open(dsn string) (driver.Conn, error) {
 		}
 	} else {
 		c.tcpConn = t.(*net.TCPConn)
+		c.setConnection(t.(net.Conn))
 	}
 
 	err = c.decodeHandshakePacket()
@@ -127,6 +136,27 @@ func (d Driver) Open(dsn string) (driver.Conn, error) {
 		return nil, err
 	}
 
+	c.HandshakeResponse = c.NewHandshakeResponse()
+
+	// Send SSL_Request Packet
+	if c.Config.SSL {
+		err = c.writeSSLRequestPacket()
+		if err != nil {
+			return nil, err
+		}
+
+		tlsConf := NewClientTLSConfig(
+			c.Config.SSLKey,
+			c.Config.SSLCer,
+			[]byte(c.Config.SSLCA),
+			c.Config.VerifyCert,
+			c.Config.Host,
+		)
+
+		c.secTCPConn = tls.Client(c.tcpConn, tlsConf)
+		c.setConnection(c.secTCPConn)
+	}
+
 	err = c.writeHandshakeResponse()
 	if err != nil {
 		return nil, err
@@ -138,24 +168,36 @@ func (d Driver) Open(dsn string) (driver.Conn, error) {
 }
 
 func (c *Conn) listen() error {
+	fmt.Println("LISTEN")
 	ph, err := c.getPacketHeader()
+
 	if err != nil {
 		return err
 	}
 
 	switch ph.Status {
 	case 0x01:
-		p, err := c.decodeAuthMoreDataResponsePacket(ph)
+		fmt.Println("IN: AuthMoreDate PACKET")
+		_, err := c.decodeAuthMoreDataResponsePacket(ph)
 		if err != nil {
 			return err
 		}
-		fmt.Printf("%+v", p)
 	case 0x00:
-		fmt.Println("OK")
+		fmt.Println("IN: OK PACKET")
 	case 0xFE:
-		fmt.Println("EOF")
+		fmt.Println("IN: EOF PACKET")
 	case 0xFF:
-		fmt.Println("ERROR")
+		fmt.Println("IN: ERROR PACKET")
+		ep, err := c.decodeErrorPacket(ph)
+		if err != nil {
+			return err
+		}
+
+		err = errors.New(fmt.Sprintf("Error %d: %s", ep.ErrorCode, ep.ErrorMessage))
+		return err
+	default:
+		fmt.Printf("ph = %+v\n", ph)
+		fmt.Println("IN: UNKNOWN PACKET")
 	}
 
 	err = c.scanner.Err()
@@ -196,16 +238,6 @@ func init() {
 }
 
 func (c *Conn) readBytes(l uint64) *bytes.Buffer {
-	if c.buffer == nil {
-		c.buffer = bufio.NewReadWriter(
-			bufio.NewReader(c.tcpConn),
-			bufio.NewWriter(c.tcpConn),
-		)
-
-		c.scanner = bufio.NewScanner(c.buffer.Reader)
-		c.scanner.Split(splitByBytesFunc)
-	}
-
 	b := make([]byte, 0)
 	for i := uint64(0); i < l; i++ {
 		c.scanner.Scan()
@@ -215,6 +247,31 @@ func (c *Conn) readBytes(l uint64) *bytes.Buffer {
 	return bytes.NewBuffer(b)
 }
 
+func (c *Conn) getBytesUntilEOF() *bytes.Buffer {
+	l := uint64(1)
+	s := c.readBytes(l)
+	b := s.Bytes()
+
+	for true {
+		if uint64(s.Len()) != l || s.Bytes()[0] == NullByte {
+			break
+		}
+
+		s = c.readBytes(uint64(l))
+
+		err := c.scanner.Err()
+		if err == io.EOF {
+			return bytes.NewBuffer(b)
+		} else if err != nil {
+			panic(err)
+		}
+
+		b = append(b, s.Bytes()...)
+	}
+
+	return bytes.NewBuffer(b)
+}
+
 func (c *Conn) getBytesUntilNull() *bytes.Buffer {
 	l := uint64(1)
 	s := c.readBytes(l)
@@ -259,6 +316,8 @@ func (c *Conn) getString(t int, l uint64) string {
 		v = c.decFixedString(l)
 	case TypeNullTerminatedString:
 		v = c.decNullTerminatedString()
+	case TypeRestOfPacketString:
+		v = c.decRestOfPacketString()
 	default:
 		v = ""
 	}
@@ -266,6 +325,11 @@ func (c *Conn) getString(t int, l uint64) string {
 	return v
 }
 
+func (c *Conn) decRestOfPacketString() string {
+	b := c.getBytesUntilEOF()
+	return string(b.Bytes())
+}
+
 func (c *Conn) decNullTerminatedString() string {
 	b := c.getBytesUntilNull()
 	return strings.TrimRight(b.String(), string(NullByte))
@@ -322,11 +386,11 @@ func (c *Conn) encLenEncInt(v uint64) []byte {
 	prefix := make([]byte, 1)
 	var b []byte
 	switch {
-	case v < MaxUint8:
+	case v < MaxUint08:
 		b = make([]byte, 2)
 		binary.LittleEndian.PutUint16(b, uint16(v))
 		b = b[:1]
-	case v >= MaxUint8 && v < MaxUint16:
+	case v >= MaxUint08 && v < MaxUint16:
 		prefix[0] = 0xFC
 		b = make([]byte, 3)
 		binary.LittleEndian.PutUint16(b, uint16(v))
@@ -494,10 +558,21 @@ func (c *Conn) Flush() error {
 
 	c.writeBuf = c.addHeader()
 	_, _ = c.buffer.Write(c.writeBuf.Bytes())
+
+	// log all packets
+	fmt.Printf(
+		"\nOUT:\n%08b\n%x\n%s\n\n",
+		c.writeBuf.Bytes(),
+		c.writeBuf.Bytes(),
+		c.writeBuf.Bytes(),
+	)
+
 	if c.buffer.Flush() != nil {
 		return c.buffer.Flush()
 	}
 
+	c.writeBuf = nil
+
 	return nil
 }
 
@@ -517,3 +592,39 @@ func (c *Conn) setupWriteBuffer() {
 		c.writeBuf = bytes.NewBuffer(nil)
 	}
 }
+
+type ErrorPacket struct {
+	PacketHeader
+	ErrorCode      uint64
+	ErrorMessage   string
+	SQLStateMarker string
+	SQLState       string
+}
+
+func (c *Conn) decodeErrorPacket(ph PacketHeader) (*ErrorPacket, error) {
+	ep := ErrorPacket{}
+	ep.PacketHeader = ph
+	ep.ErrorCode = c.getInt(TypeFixedInt, 2)
+	ep.SQLStateMarker = c.getString(TypeFixedString, 1)
+	ep.SQLState = c.getString(TypeFixedString, 5)
+	ep.ErrorMessage = c.getString(TypeRestOfPacketString, 0)
+
+	err := c.scanner.Err()
+	if err != nil {
+		return nil, err
+	}
+
+	return &ep, nil
+}
+
+func (c *Conn) setConnection(nc net.Conn) {
+	c.curConn = nc
+
+	c.buffer = bufio.NewReadWriter(
+		bufio.NewReader(c.curConn),
+		bufio.NewWriter(c.curConn),
+	)
+
+	c.scanner = bufio.NewScanner(c.buffer.Reader)
+	c.scanner.Split(splitByBytesFunc)
+}

+ 71 - 7
binlog/handshake.go

@@ -2,7 +2,10 @@ package binlog
 
 import (
 	"bytes"
+	"crypto/tls"
+	"crypto/x509"
 	"fmt"
+	"io/ioutil"
 )
 
 type Capabilities struct {
@@ -83,6 +86,13 @@ type HandshakeResponse struct {
 	KeyValues          map[string]string
 }
 
+type SSLRequest struct {
+	ClientFlag    *Capabilities
+	MaxPacketSize uint64
+	CharacterSet  uint64
+	Username      string
+}
+
 func (c *Conn) decodeCapabilityFlags(hs *Handshake) {
 	var cfb = append(hs.CapabilityFlags1.Bytes(), hs.CapabilityFlags2.Bytes()...)
 	capabilities := c.bitmaskToStruct(cfb, hs.Capabilities).(Capabilities)
@@ -127,8 +137,7 @@ func (c *Conn) decodeHandshakePacket() error {
 }
 
 func (c *Conn) writeHandshakeResponse() error {
-	hr := c.NewHandshakeResponse()
-	c.HandshakeResponse = hr
+	hr := c.HandshakeResponse
 	cf := c.structToBitmask(hr.ClientFlag)
 	c.putBytes(cf)
 	c.putInt(TypeFixedInt, hr.MaxPacketSize, 4)
@@ -161,7 +170,20 @@ func (c *Conn) writeHandshakeResponse() error {
 		c.putNullBytes(1)
 	}
 
-	fmt.Printf("%+v\n", hr)
+	if c.Flush() != nil {
+		return c.Flush()
+	}
+
+	return nil
+}
+
+func (c *Conn) writeSSLRequestPacket() error {
+	sr := c.NewSSLRequest()
+	cf := c.structToBitmask(sr.ClientFlag)
+	c.putBytes(cf)
+	c.putInt(TypeFixedInt, sr.MaxPacketSize, 4)
+	c.putInt(TypeFixedInt, sr.CharacterSet, 1)
+	c.putNullBytes(23)
 
 	if c.Flush() != nil {
 		return c.Flush()
@@ -170,13 +192,22 @@ func (c *Conn) writeHandshakeResponse() error {
 	return nil
 }
 
+func (c *Conn) NewSSLRequest() *SSLRequest {
+	return &SSLRequest{
+		ClientFlag:    c.HandshakeResponse.ClientFlag,
+		MaxPacketSize: c.HandshakeResponse.MaxPacketSize,
+		CharacterSet:  c.HandshakeResponse.CharacterSet,
+		Username:      c.HandshakeResponse.Username,
+	}
+}
+
 func (c *Conn) NewHandshakeResponse() *HandshakeResponse {
 	return &HandshakeResponse{
 		ClientFlag: &Capabilities{
 			LongPassword:               true,
 			FoundRows:                  true,
 			LongFlag:                   false,
-			ConnectWithDB:              true,
+			ConnectWithDB:              false,
 			NoSchema:                   false,
 			Compress:                   false,
 			ODBC:                       false,
@@ -184,7 +215,7 @@ func (c *Conn) NewHandshakeResponse() *HandshakeResponse {
 			IgnoreSpace:                true,
 			Protocol41:                 true,
 			Interactive:                true,
-			SSL:                        false,
+			SSL:                        c.Config.SSL,
 			IgnoreSigpipe:              false,
 			Transactions:               true,
 			LegacyProtocol41:           false,
@@ -196,9 +227,9 @@ func (c *Conn) NewHandshakeResponse() *HandshakeResponse {
 			ConnectAttrs:               false,
 			PluginAuthLenEncClientData: false,
 			CanHandleExpiredPasswords:  false,
-			SessionTrack:               false,
+			SessionTrack:               true,
 			DeprecateEOF:               false,
-			SSLVerifyServerCert:        false,
+			SSLVerifyServerCert:        c.Config.VerifyCert,
 			OptionalResultSetMetadata:  false,
 			RememberOptions:            false,
 		},
@@ -212,3 +243,36 @@ func (c *Conn) NewHandshakeResponse() *HandshakeResponse {
 		KeyValues:          nil,
 	}
 }
+
+// generate TLS config for client side
+// if insecureSkipVerify is set to true, serverName will not be validated
+func NewClientTLSConfig(keyPem string, cerPem string, caPem []byte, insecureSkipVerify bool, serverName string) *tls.Config {
+	fmt.Printf("insecureSkipVerify = %+v\n", insecureSkipVerify)
+	config := &tls.Config{
+		InsecureSkipVerify: !insecureSkipVerify,
+		ServerName:         serverName,
+	}
+
+	if caPem != nil {
+		ca, err := ioutil.ReadFile(string(caPem))
+		if err == nil {
+			pool := x509.NewCertPool()
+			if !pool.AppendCertsFromPEM(ca) {
+				panic("failed to add ca PEM")
+			}
+
+			config.RootCAs = pool
+		}
+	}
+
+	if keyPem != "" && cerPem != "" {
+		cert, err := tls.LoadX509KeyPair(cerPem, keyPem)
+		if err != nil {
+			panic(err)
+		}
+
+		config.Certificates = []tls.Certificate{cert}
+	}
+
+	return config
+}

+ 5 - 1
config.json

@@ -4,5 +4,9 @@
   "user": "root",
   "password": "root",
   "database": "information_schema",
-  "ssl": false
+  "ssl": true,
+  "ssl-key": "",
+  "ssl-cer": "",
+  "ssl-ca": "/Users/josh/Sites/Certificates.pem",
+  "verify-cert": false
 }

+ 1 - 0
docker-compose.yaml

@@ -7,6 +7,7 @@ services:
       MYSQL_DATABASE: root
     ports:
       - "3318:3306"
+    command: "--ssl"
     # command: "--default-authentication-plugin=mysql_native_password"
     restart: always
   mysql57: