connection.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631
  1. package binlog
  2. import (
  3. "bufio"
  4. "bytes"
  5. "crypto/tls"
  6. "database/sql"
  7. "database/sql/driver"
  8. "encoding/binary"
  9. "encoding/json"
  10. "errors"
  11. "fmt"
  12. "io/ioutil"
  13. "math"
  14. "net"
  15. "reflect"
  16. "strings"
  17. "time"
  18. )
  // MySQL packet data types. They select the wire encoding used by
// getInt/getString when decoding and putInt/putString when encoding.
const (
	TypeNullTerminatedString int = iota // string terminated by a 0x00 byte
	TypeFixedString                     // string of a caller-supplied length
	TypeFixedInt                        // little-endian integer of a fixed width
	TypeLenEncInt                       // MySQL length-encoded integer
	TypeRestOfPacketString              // string running to the end of the packet
)

// Integer maximums.
const (
	MaxUint08 = 1<<8 - 1
	MaxUint16 = 1<<16 - 1
	MaxUint24 = 1<<24 - 1
	MaxUint64 = 1<<64 - 1
)

// Misc. constants.
const (
	// NullByte terminates null-terminated strings.
	NullByte byte = 0
	// MaxPacketSize is the largest packet size this package assumes.
	// NOTE(review): the packet length field is 3 bytes (see
	// getPacketHeader/addHeader), so payloads may reach MaxUint24;
	// MaxUint16 looks too small — confirm intent.
	MaxPacketSize = MaxUint16
)
  // Config holds the connection settings decoded from the JSON file named by
// the DSN (see newBinlogConfig).
type Config struct {
	Host     string `json:"host"`
	Port     int    `json:"port"`
	User     string `json:"user"`
	Pass     string `json:"password"`
	Database string `json:"database"`

	// SSL enables the TLS upgrade during the handshake; the SSL* values and
	// VerifyCert are handed to NewClientTLSConfig by Driver.Open.
	SSL        bool   `json:"ssl"`
	SSLCA      string `json:"ssl-ca"`
	SSLCer     string `json:"ssl-cer"`
	SSLKey     string `json:"ssl-key"`
	VerifyCert bool   `json:"verify-cert"`

	// Timeout bounds dialing the server. It has no json tag, so
	// encoding/json matches the key "Timeout" (case-insensitively) and, being
	// a time.Duration, expects an integer number of nanoseconds.
	Timeout time.Duration
}
  46. func newBinlogConfig(dsn string) (*Config, error) {
  47. var err error
  48. b, err := ioutil.ReadFile(dsn)
  49. if err != nil {
  50. return nil, err
  51. }
  52. config := Config{}
  53. err = json.Unmarshal(b, &config)
  54. return &config, err
  55. }
  56. type Conn struct {
  57. Config *Config
  58. curConn net.Conn
  59. tcpConn *net.TCPConn
  60. secTCPConn *tls.Conn
  61. Handshake *Handshake
  62. HandshakeResponse *HandshakeResponse
  63. buffer *bufio.ReadWriter
  64. scanner *bufio.Scanner
  65. err error
  66. sequenceId uint64
  67. writeBuf *bytes.Buffer
  68. }
  69. func newBinlogConn(config *Config) Conn {
  70. return Conn{
  71. Config: config,
  72. sequenceId: 1,
  73. }
  74. }
  75. func (c Conn) Prepare(query string) (driver.Stmt, error) {
  76. return nil, nil
  77. }
  78. func (c Conn) Close() error {
  79. return nil
  80. }
  81. func (c Conn) Begin() (driver.Tx, error) {
  82. return nil, nil
  83. }
  // Driver is the database/sql driver registered as "mysql-binlog". Its Open
// method treats the DSN as the path of a JSON configuration file.
type Driver struct{}
  85. func (d Driver) Open(dsn string) (driver.Conn, error) {
  86. config, err := newBinlogConfig(dsn)
  87. if nil != err {
  88. return nil, err
  89. }
  90. c := newBinlogConn(config)
  91. var t interface{}
  92. dialer := net.Dialer{Timeout: c.Config.Timeout}
  93. addr := fmt.Sprintf("%s:%d", c.Config.Host, c.Config.Port)
  94. t, err = dialer.Dial("tcp", addr)
  95. if err != nil {
  96. netErr, ok := err.(net.Error)
  97. if ok && !netErr.Temporary() {
  98. fmt.Printf("Error: %s", netErr.Error())
  99. return nil, err
  100. }
  101. } else {
  102. c.tcpConn = t.(*net.TCPConn)
  103. c.setConnection(t.(net.Conn))
  104. }
  105. err = c.decodeHandshakePacket()
  106. if err != nil {
  107. return nil, err
  108. }
  109. c.HandshakeResponse = c.NewHandshakeResponse()
  110. // If we are on SSL send SSL_Request packet now
  111. if c.Config.SSL {
  112. err = c.writeSSLRequestPacket()
  113. if err != nil {
  114. return nil, err
  115. }
  116. tlsConf := NewClientTLSConfig(
  117. c.Config.SSLKey,
  118. c.Config.SSLCer,
  119. []byte(c.Config.SSLCA),
  120. c.Config.VerifyCert,
  121. c.Config.Host,
  122. )
  123. c.secTCPConn = tls.Client(c.tcpConn, tlsConf)
  124. c.setConnection(c.secTCPConn)
  125. }
  126. err = c.writeHandshakeResponse()
  127. if err != nil {
  128. return nil, err
  129. }
  130. err = c.listen()
  131. return c, err
  132. }
  133. func (c *Conn) listen() error {
  134. ph, err := c.getPacketHeader()
  135. if err != nil {
  136. return err
  137. }
  138. switch ph.Status {
  139. case 0x01:
  140. fmt.Println("IN: AuthMoreDate PACKET")
  141. md, err := c.decodeAuthMoreDataResponsePacket(ph)
  142. if err != nil {
  143. return err
  144. }
  145. switch md.Data {
  146. case SHA2_FAST_AUTH_SUCCESS:
  147. fmt.Println("FAST AUTH")
  148. case SHA2_REQUEST_PUBLIC_KEY:
  149. fmt.Println("REQUEST PUBLIC KEY")
  150. case SHA2_PERFORM_FULL_AUTHENTICATION:
  151. fmt.Println("FULL AUTH")
  152. c.putBytes(append([]byte(c.Config.Pass), NullByte))
  153. if c.Flush() != nil {
  154. return c.Flush()
  155. }
  156. }
  157. case 0x00:
  158. fmt.Println("IN: OK PACKET")
  159. case 0xFE:
  160. fmt.Println("IN: EOF PACKET")
  161. case 0xFF:
  162. fmt.Println("IN: ERROR PACKET")
  163. ep, err := c.decodeErrorPacket(ph)
  164. if err != nil {
  165. return err
  166. }
  167. err = errors.New(fmt.Sprintf("Error %d: %s", ep.ErrorCode, ep.ErrorMessage))
  168. return err
  169. }
  170. err = c.scanner.Err()
  171. if err != nil {
  172. return err
  173. }
  174. err = c.listen() // Listen forever until we get an error.
  175. if err != nil {
  176. return err
  177. }
  178. return nil
  179. }
  // PacketHeader is what getPacketHeader decodes at the start of every
// incoming packet.
type PacketHeader struct {
	// Length is the payload length from the 3-byte length field.
	Length uint64
	// SequenceID is the packet's 1-byte sequence id.
	SequenceID uint64
	// Status is the first payload byte (e.g. 0x00 OK, 0xFF ERR in listen).
	Status uint64
}
  185. func (c *Conn) getPacketHeader() (PacketHeader, error) {
  186. ph := PacketHeader{}
  187. ph.Length = c.getInt(TypeFixedInt, 3)
  188. ph.SequenceID = c.getInt(TypeFixedInt, 1)
  189. ph.Status = c.getInt(TypeFixedInt, 1)
  190. err := c.scanner.Err()
  191. if err != nil {
  192. return ph, err
  193. }
  194. return ph, nil
  195. }
  196. func init() {
  197. sql.Register("mysql-binlog", &Driver{})
  198. }
  199. func (c *Conn) readBytes(l uint64) *bytes.Buffer {
  200. b := make([]byte, 0)
  201. for i := uint64(0); i < l; i++ {
  202. didScan := c.scanner.Scan()
  203. if !didScan {
  204. return nil
  205. }
  206. b = append(b, c.scanner.Bytes()...)
  207. }
  208. return bytes.NewBuffer(b)
  209. }
  210. func (c *Conn) getBytesUntilEOF() *bytes.Buffer {
  211. l := uint64(1)
  212. s := c.readBytes(l)
  213. b := s.Bytes()
  214. for true {
  215. if uint64(s.Len()) != l || s.Bytes()[0] == NullByte {
  216. break
  217. }
  218. s := c.readBytes(uint64(l))
  219. if s == nil {
  220. return bytes.NewBuffer(b)
  221. }
  222. b = append(b, s.Bytes()...)
  223. }
  224. return bytes.NewBuffer(b)
  225. }
  226. func (c *Conn) getBytesUntilNull() *bytes.Buffer {
  227. l := uint64(1)
  228. s := c.readBytes(l)
  229. b := s.Bytes()
  230. for true {
  231. if uint64(s.Len()) != l || s.Bytes()[0] == NullByte {
  232. break
  233. }
  234. s = c.readBytes(uint64(l))
  235. b = append(b, s.Bytes()...)
  236. }
  237. return bytes.NewBuffer(b)
  238. }
  239. func (c *Conn) discardBytes(l int) {
  240. for i := 0; i < l; i++ {
  241. c.scanner.Scan()
  242. }
  243. }
  244. func (c *Conn) getInt(t int, l uint64) uint64 {
  245. var v uint64
  246. switch t {
  247. case TypeFixedInt:
  248. v = c.decFixedInt(l)
  249. default:
  250. v = 0
  251. }
  252. return v
  253. }
  254. func (c *Conn) getString(t int, l uint64) string {
  255. var v string
  256. switch t {
  257. case TypeFixedString:
  258. v = c.decFixedString(l)
  259. case TypeNullTerminatedString:
  260. v = c.decNullTerminatedString()
  261. case TypeRestOfPacketString:
  262. v = c.decRestOfPacketString()
  263. default:
  264. v = ""
  265. }
  266. return v
  267. }
  268. func (c *Conn) decRestOfPacketString() string {
  269. b := c.getBytesUntilEOF()
  270. return string(b.Bytes())
  271. }
  272. func (c *Conn) decNullTerminatedString() string {
  273. b := c.getBytesUntilNull()
  274. return strings.TrimRight(b.String(), string(NullByte))
  275. }
  276. func (c *Conn) decFixedString(l uint64) string {
  277. b := c.readBytes(l)
  278. return b.String()
  279. }
  280. func (c *Conn) decFixedInt(l uint64) uint64 {
  281. var i uint64
  282. b := c.readBytes(l)
  283. if l <= 2 {
  284. var x uint16
  285. pb := c.padBytes(2, b.Bytes())
  286. br := bytes.NewReader(pb)
  287. _ = binary.Read(br, binary.LittleEndian, &x)
  288. i = uint64(x)
  289. } else if l <= 4 {
  290. var x uint32
  291. pb := c.padBytes(4, b.Bytes())
  292. br := bytes.NewReader(pb)
  293. _ = binary.Read(br, binary.LittleEndian, &x)
  294. i = uint64(x)
  295. } else if l <= 8 {
  296. var x uint64
  297. pb := c.padBytes(8, b.Bytes())
  298. br := bytes.NewReader(pb)
  299. _ = binary.Read(br, binary.LittleEndian, &x)
  300. i = x
  301. }
  302. return i
  303. }
  304. func (c *Conn) padBytes(l int, b []byte) []byte {
  305. bl := len(b)
  306. pl := l - bl
  307. for i := 0; i < pl; i++ {
  308. b = append(b, NullByte)
  309. }
  310. return b
  311. }
  312. func (c *Conn) encFixedLenInt(v uint64, l uint64) []byte {
  313. b := make([]byte, 8)
  314. binary.LittleEndian.PutUint64(b, v)
  315. return b[:l]
  316. }
  317. func (c *Conn) encLenEncInt(v uint64) []byte {
  318. prefix := make([]byte, 1)
  319. var b []byte
  320. switch {
  321. case v < MaxUint08:
  322. b = make([]byte, 2)
  323. binary.LittleEndian.PutUint16(b, uint16(v))
  324. b = b[:1]
  325. case v >= MaxUint08 && v < MaxUint16:
  326. prefix[0] = 0xFC
  327. b = make([]byte, 3)
  328. binary.LittleEndian.PutUint16(b, uint16(v))
  329. b = b[:2]
  330. case v >= MaxUint16 && v < MaxUint24:
  331. prefix[0] = 0xFD
  332. b = make([]byte, 4)
  333. binary.LittleEndian.PutUint32(b, uint32(v))
  334. b = b[:3]
  335. case v >= MaxUint24 && v < MaxUint64:
  336. prefix[0] = 0xFE
  337. b = make([]byte, 9)
  338. binary.LittleEndian.PutUint64(b, uint64(v))
  339. }
  340. if len(b) > 1 {
  341. b = append(prefix, b...)
  342. }
  343. return b
  344. }
  345. func (c *Conn) bitmaskToStruct(b []byte, s interface{}) interface{} {
  346. l := len(b)
  347. t := reflect.TypeOf(s)
  348. v := reflect.New(t.Elem()).Elem()
  349. for i := uint(0); i < uint(v.NumField()); i++ {
  350. f := v.Field(int(i))
  351. var v bool
  352. switch {
  353. case l > 4:
  354. x := binary.LittleEndian.Uint64(b)
  355. flag := uint64(1 << i)
  356. v = x&flag > 0
  357. case l > 2:
  358. x := binary.LittleEndian.Uint32(b)
  359. flag := uint32(1 << i)
  360. v = x&flag > 0
  361. case l > 1:
  362. x := binary.LittleEndian.Uint16(b)
  363. flag := uint16(1 << i)
  364. v = x&flag > 0
  365. default:
  366. x := uint(b[0])
  367. flag := uint(1 << i)
  368. v = x&flag > 0
  369. }
  370. f.SetBool(v)
  371. }
  372. return v.Interface()
  373. }
  374. func (c *Conn) structToBitmask(s interface{}) []byte {
  375. t := reflect.TypeOf(s).Elem()
  376. sV := reflect.ValueOf(s).Elem()
  377. fC := uint(t.NumField())
  378. m := uint64(0)
  379. for i := uint(0); i < fC; i++ {
  380. f := sV.Field(int(i))
  381. v := f.Bool()
  382. if v {
  383. m |= 1 << i
  384. }
  385. }
  386. l := uint64(math.Ceil(float64(fC) / 8.0))
  387. b := make([]byte, 8)
  388. binary.LittleEndian.PutUint64(b, m)
  389. switch {
  390. case l > 4: // 64 bits
  391. b = b[:8]
  392. case l > 2: // 32 bits
  393. b = b[:4]
  394. case l > 1: // 16 bits
  395. b = b[:2]
  396. default: // 8 bits
  397. b = b[:1]
  398. }
  399. return b
  400. }
  401. func (c *Conn) putString(t int, v string) uint64 {
  402. b := make([]byte, 0)
  403. switch t {
  404. case TypeFixedString:
  405. b = c.encFixedString(v)
  406. case TypeNullTerminatedString:
  407. b = c.encNullTerminatedString(v)
  408. case TypeRestOfPacketString:
  409. b = c.encRestOfPacketString(v)
  410. }
  411. l, err := c.writeBuf.Write(b)
  412. if err != nil {
  413. c.err = err
  414. }
  415. return uint64(l)
  416. }
  417. func (c *Conn) encNullTerminatedString(v string) []byte {
  418. return append([]byte(v), NullByte)
  419. }
  420. func (c *Conn) encFixedString(v string) []byte {
  421. return []byte(v)
  422. }
  423. func (c *Conn) encRestOfPacketString(v string) []byte {
  424. s := c.encFixedString(v)
  425. return s
  426. }
  427. func (c *Conn) putInt(t int, v uint64, l uint64) uint64 {
  428. c.setupWriteBuffer()
  429. b := make([]byte, 0)
  430. switch t {
  431. case TypeFixedInt:
  432. b = c.encFixedLenInt(v, l)
  433. case TypeLenEncInt:
  434. b = c.encLenEncInt(v)
  435. }
  436. n, err := c.writeBuf.Write(b)
  437. if err != nil {
  438. c.err = err
  439. }
  440. return uint64(n)
  441. }
  442. func (c *Conn) putNullBytes(n uint64) uint64 {
  443. c.setupWriteBuffer()
  444. b := make([]byte, n)
  445. l, err := c.writeBuf.Write(b)
  446. if err != nil {
  447. c.err = err
  448. }
  449. return uint64(l)
  450. }
  451. func (c *Conn) putBytes(v []byte) uint64 {
  452. c.setupWriteBuffer()
  453. l, err := c.writeBuf.Write(v)
  454. if err != nil {
  455. c.err = err
  456. }
  457. return uint64(l)
  458. }
  459. func (c *Conn) Flush() error {
  460. if c.err != nil {
  461. return c.err
  462. }
  463. fmt.Println(string(c.writeBuf.Bytes()))
  464. c.writeBuf = c.addHeader()
  465. _, _ = c.buffer.Write(c.writeBuf.Bytes())
  466. // log all outgoing packets
  467. fmt.Printf(
  468. "\nOUT:\n%08b\n%x\n%s\n\n",
  469. c.writeBuf.Bytes(),
  470. c.writeBuf.Bytes(),
  471. c.writeBuf.Bytes(),
  472. )
  473. if c.buffer.Flush() != nil {
  474. return c.buffer.Flush()
  475. }
  476. c.writeBuf = nil
  477. return nil
  478. }
  479. func (c *Conn) addHeader() *bytes.Buffer {
  480. pl := uint64(c.writeBuf.Len())
  481. sId := uint64(c.sequenceId)
  482. c.sequenceId++
  483. plB := c.encFixedLenInt(pl, 3)
  484. sIdB := c.encFixedLenInt(sId, 1)
  485. return bytes.NewBuffer(append(append(plB, sIdB...), c.writeBuf.Bytes()...))
  486. }
  487. func (c *Conn) setupWriteBuffer() {
  488. if c.writeBuf == nil {
  489. c.writeBuf = bytes.NewBuffer(nil)
  490. }
  491. }
  492. type ErrorPacket struct {
  493. PacketHeader
  494. ErrorCode uint64
  495. ErrorMessage string
  496. SQLStateMarker string
  497. SQLState string
  498. }
  499. func (c *Conn) decodeErrorPacket(ph PacketHeader) (*ErrorPacket, error) {
  500. ep := ErrorPacket{}
  501. ep.PacketHeader = ph
  502. ep.ErrorCode = c.getInt(TypeFixedInt, 2)
  503. ep.SQLStateMarker = c.getString(TypeFixedString, 1)
  504. ep.SQLState = c.getString(TypeFixedString, 5)
  505. ep.ErrorMessage = c.getString(TypeRestOfPacketString, 0)
  506. err := c.scanner.Err()
  507. if err != nil {
  508. return nil, err
  509. }
  510. return &ep, nil
  511. }
  512. func (c *Conn) setConnection(nc net.Conn) {
  513. c.curConn = nc
  514. c.buffer = bufio.NewReadWriter(
  515. bufio.NewReader(c.curConn),
  516. bufio.NewWriter(c.curConn),
  517. )
  518. c.scanner = bufio.NewScanner(c.buffer.Reader)
  519. c.scanner.Split(bufio.ScanBytes)
  520. }