// connection.go (13 KB)
  1. package binlog
  2. import (
  3. "bufio"
  4. "bytes"
  5. "crypto/tls"
  6. "database/sql"
  7. "database/sql/driver"
  8. "encoding/binary"
  9. "encoding/json"
  10. "errors"
  11. "fmt"
  12. "io/ioutil"
  13. "math"
  14. "net"
  15. "reflect"
  16. "strings"
  17. "time"
  18. )
  // Miscellaneous constants.
const (
	// NullByte is the zero byte that terminates null-terminated strings and
	// pads short integers before they are decoded.
	NullByte byte = 0

	// MaxPacketSize is the package's maximum packet size constant.
	// NOTE(review): presumably advertised to the server in the handshake
	// response — confirm where it is used.
	MaxPacketSize = MaxUint16
)

// EOF is the buffer readBytes hands back once the scanner yields no more
// bytes. It holds a single NullByte so null-terminated scans stop on it.
// It is one shared value: callers may inspect it but must never read from
// (drain) or write to it.
var EOF = bytes.NewBuffer([]byte{NullByte})

// MySQL packet data types: the encoding selector passed to getInt/getString
// and putInt/putString.
const (
	TypeNullTerminatedString int = iota // bytes up to a NullByte
	TypeFixedString                     // a fixed number of bytes
	TypeFixedInt                        // a fixed-width little-endian integer
	TypeLenEncInt                       // a length-encoded integer
	TypeRestOfPacketString              // everything left in the packet
	TypeLenEncString                    // a length-encoded string
)

// Integer maximums.
const (
	MaxUint08 = 0xFF
	MaxUint16 = 0xFFFF
	MaxUint24 = 0xFFFFFF
	MaxUint64 = 0xFFFFFFFFFFFFFFFF
)

// Packet statuses: the first payload byte of a server packet, which
// readPacket dispatches on.
const (
	StatusOK   = 0x00
	StatusEOF  = 0xFE
	StatusErr  = 0xFF
	StatusAuth = 0x01 // auth-more-data
)
  // Config holds the settings for one binlog connection. It is decoded from a
// JSON file whose path is used as the DSN (see newBinlogConfig).
type Config struct {
	Host     string `json:"host"` // dialed together with Port by Driver.Open
	Port     int    `json:"port"`
	User     string `json:"user"`
	Pass     string `json:"password"`
	Database string `json:"database"`

	// TLS settings. When SSL is true Driver.Open switches the connection to
	// TLS with a config built by NewClientTLSConfig from the fields below.
	SSL bool `json:"ssl"`
	// SSLCA is handed to NewClientTLSConfig as []byte.
	// NOTE(review): presumably PEM data rather than a file path — confirm.
	SSLCA      string `json:"ssl-ca"`
	SSLCer     string `json:"ssl-cer"`
	SSLKey     string `json:"ssl-key"`
	VerifyCert bool   `json:"verify-cert"`

	// NOTE(review): presumably the replica server id used when registering
	// and the binlog file to stream from — confirm against their users.
	ServerId   uint64 `json:"server-id"`
	BinlogFile string `json:"binlog-file"`

	// Timeout bounds the TCP dial in Driver.Open. It has no json tag, so the
	// config file keys it as "Timeout" and gives it in nanoseconds.
	Timeout time.Duration
}

// newBinlogConfig reads the JSON file at path dsn and decodes it into a
// Config. If the file cannot be read, or its contents are not valid JSON for
// Config, it returns a nil Config and an error wrapping the cause. It never
// returns a partially decoded Config.
func newBinlogConfig(dsn string) (*Config, error) {
	raw, err := ioutil.ReadFile(dsn)
	if err != nil {
		return nil, fmt.Errorf("reading binlog config %q: %w", dsn, err)
	}
	var config Config
	if err := json.Unmarshal(raw, &config); err != nil {
		return nil, fmt.Errorf("parsing binlog config %q: %w", dsn, err)
	}
	return &config, nil
}
  65. type Conn struct {
  66. Config *Config
  67. curConn net.Conn
  68. tcpConn *net.TCPConn
  69. secTCPConn *tls.Conn
  70. Handshake *Handshake
  71. HandshakeResponse *HandshakeResponse
  72. buffer *bufio.ReadWriter
  73. scanner *bufio.Scanner
  74. err error
  75. sequenceId uint64
  76. writeBuf *bytes.Buffer
  77. StatusFlags *StatusFlags
  78. Listener *net.Listener
  79. packetHeader *PacketHeader
  80. scanPos uint64
  81. }
  82. func newBinlogConn(config *Config) Conn {
  83. return Conn{
  84. Config: config,
  85. sequenceId: 1,
  86. }
  87. }
  88. func (c Conn) Prepare(query string) (driver.Stmt, error) {
  89. return nil, nil
  90. }
  91. func (c Conn) Close() error {
  92. return nil
  93. }
  94. func (c Conn) Begin() (driver.Tx, error) {
  95. return nil, nil
  96. }
  97. type Driver struct{}
  98. func (d Driver) Open(dsn string) (driver.Conn, error) {
  99. config, err := newBinlogConfig(dsn)
  100. if nil != err {
  101. return nil, err
  102. }
  103. c := newBinlogConn(config)
  104. var t interface{}
  105. dialer := net.Dialer{Timeout: c.Config.Timeout}
  106. addr := fmt.Sprintf("%s:%d", c.Config.Host, c.Config.Port)
  107. t, err = dialer.Dial("tcp", addr)
  108. if err != nil {
  109. netErr, ok := err.(net.Error)
  110. if ok && !netErr.Temporary() {
  111. fmt.Printf("Error: %s", netErr.Error())
  112. return nil, err
  113. }
  114. } else {
  115. c.tcpConn = t.(*net.TCPConn)
  116. c.setConnection(t.(net.Conn))
  117. }
  118. err = c.decodeHandshakePacket()
  119. if err != nil {
  120. return nil, err
  121. }
  122. c.HandshakeResponse = c.NewHandshakeResponse()
  123. // If we are on SSL send SSL_Request packet now
  124. if c.Config.SSL {
  125. err = c.writeSSLRequestPacket()
  126. if err != nil {
  127. return nil, err
  128. }
  129. tlsConf := NewClientTLSConfig(
  130. c.Config.SSLKey,
  131. c.Config.SSLCer,
  132. []byte(c.Config.SSLCA),
  133. c.Config.VerifyCert,
  134. c.Config.Host,
  135. )
  136. c.secTCPConn = tls.Client(c.tcpConn, tlsConf)
  137. c.setConnection(c.secTCPConn)
  138. }
  139. err = c.writeHandshakeResponse()
  140. if err != nil {
  141. return nil, err
  142. }
  143. // Listen for auth response.
  144. _, err = c.readPacket()
  145. if err != nil {
  146. return nil, err
  147. }
  148. // Auth was successful.
  149. c.sequenceId = 0
  150. // Register as a slave
  151. err = c.registerAsSlave()
  152. if err != nil {
  153. return nil, err
  154. }
  155. c.sequenceId = 0
  156. _, err = c.readPacket()
  157. if err != nil {
  158. return nil, err
  159. }
  160. err = c.startBinlogStream()
  161. if err != nil {
  162. return nil, err
  163. }
  164. _, err = c.readPacket()
  165. if err != nil {
  166. return nil, err
  167. }
  168. return c, err
  169. }
  170. func (c *Conn) readPacket() (interface{}, error) {
  171. ph, err := c.getPacketHeader()
  172. if err != nil {
  173. return nil, err
  174. }
  175. var res interface{}
  176. switch ph.Status {
  177. case StatusAuth:
  178. res, err := c.decodeAuthMoreDataResponsePacket(ph)
  179. if err != nil {
  180. return nil, err
  181. }
  182. switch res.Data {
  183. case SHA2_FAST_AUTH_SUCCESS:
  184. case SHA2_REQUEST_PUBLIC_KEY:
  185. case SHA2_PERFORM_FULL_AUTHENTICATION:
  186. c.putBytes(append([]byte(c.Config.Pass), NullByte))
  187. if c.Flush() != nil {
  188. return nil, c.Flush()
  189. }
  190. }
  191. case StatusEOF:
  192. fallthrough
  193. case StatusOK:
  194. res, err = c.decodeOKPacket(ph)
  195. if err != nil {
  196. return nil, err
  197. }
  198. case StatusErr:
  199. res, err = c.decodeErrorPacket(ph)
  200. if err != nil {
  201. return nil, err
  202. }
  203. err = fmt.Errorf(
  204. "Error %d: %s",
  205. res.(*ErrorPacket).ErrorCode,
  206. res.(*ErrorPacket).ErrorMessage,
  207. )
  208. return res, err
  209. }
  210. err = c.scanner.Err()
  211. if err != nil {
  212. return nil, err
  213. }
  214. return res, nil
  215. }
  216. type PacketHeader struct {
  217. Length uint64
  218. SequenceID uint64
  219. Status uint64
  220. }
  221. func (c *Conn) getPacketHeader() (*PacketHeader, error) {
  222. ph := PacketHeader{}
  223. ph.Length = c.getInt(TypeFixedInt, 3)
  224. if ph.Length == 0 {
  225. err := errors.New("EOF")
  226. return nil, err
  227. }
  228. ph.SequenceID = c.getInt(TypeFixedInt, 1)
  229. ph.Status = c.getInt(TypeFixedInt, 1)
  230. err := c.scanner.Err()
  231. if err != nil {
  232. return &ph, err
  233. }
  234. c.packetHeader = &ph
  235. c.scanPos = 0
  236. return &ph, nil
  237. }
  238. func init() {
  239. sql.Register("mysql-binlog", &Driver{})
  240. }
  241. func (c *Conn) readBytes(l uint64) *bytes.Buffer {
  242. b := make([]byte, 0)
  243. for i := uint64(0); i < l; i++ {
  244. didScan := c.scanner.Scan()
  245. if !didScan {
  246. err := c.scanner.Err()
  247. if err == nil { // scanner reached EOF
  248. return EOF
  249. } else {
  250. panic(err) // @TODO Handle this gracefully.
  251. }
  252. }
  253. b = append(b, c.scanner.Bytes()...)
  254. }
  255. c.scanPos += uint64(len(b))
  256. return bytes.NewBuffer(b)
  257. }
  258. func (c *Conn) getBytesUntilNull() *bytes.Buffer {
  259. l := uint64(1)
  260. s := c.readBytes(l)
  261. b := s.Bytes()
  262. for {
  263. if uint64(s.Len()) != l || s.Bytes()[0] == NullByte {
  264. break
  265. }
  266. s = c.readBytes(uint64(l))
  267. b = append(b, s.Bytes()...)
  268. }
  269. return bytes.NewBuffer(b)
  270. }
  271. func (c *Conn) discardBytes(l uint64) {
  272. c.readBytes(l)
  273. }
  274. func (c *Conn) getInt(t int, l uint64) uint64 {
  275. var v uint64
  276. switch t {
  277. case TypeFixedInt:
  278. v = c.decFixedInt(l)
  279. case TypeLenEncInt:
  280. v = c.decLenEncInt()
  281. default:
  282. v = 0
  283. }
  284. return v
  285. }
  286. func (c *Conn) getString(t int, l uint64) string {
  287. var v string
  288. switch t {
  289. case TypeFixedString:
  290. v = c.decFixedString(l)
  291. case TypeLenEncString:
  292. v = string(c.decLenEncInt())
  293. case TypeNullTerminatedString:
  294. v = c.decNullTerminatedString()
  295. case TypeRestOfPacketString:
  296. v = c.decRestOfPacketString()
  297. default:
  298. v = ""
  299. }
  300. return v
  301. }
  302. func (c *Conn) decRestOfPacketString() string {
  303. b := c.getRemainingBytes()
  304. return b.String()
  305. }
  306. func (c *Conn) getRemainingBytes() *bytes.Buffer {
  307. l := (c.packetHeader.Length - 1) - c.scanPos
  308. b := c.readBytes(l)
  309. return b
  310. }
  311. func (c *Conn) decNullTerminatedString() string {
  312. b := c.getBytesUntilNull()
  313. return strings.TrimRight(b.String(), string(NullByte))
  314. }
  315. func (c *Conn) decFixedString(l uint64) string {
  316. b := c.readBytes(l)
  317. return b.String()
  318. }
  319. func (c *Conn) decLenEncInt() uint64 {
  320. var l uint16
  321. b := c.readBytes(1)
  322. br := bytes.NewReader(b.Bytes())
  323. _ = binary.Read(br, binary.LittleEndian, &l)
  324. if l > 0 {
  325. return c.decFixedInt(uint64(l))
  326. } else {
  327. return 0
  328. }
  329. }
  330. func (c *Conn) decFixedInt(l uint64) uint64 {
  331. var i uint64
  332. b := c.readBytes(l)
  333. if l <= 2 {
  334. var x uint16
  335. pb := c.padBytes(2, b.Bytes())
  336. br := bytes.NewReader(pb)
  337. _ = binary.Read(br, binary.LittleEndian, &x)
  338. i = uint64(x)
  339. } else if l <= 4 {
  340. var x uint32
  341. pb := c.padBytes(4, b.Bytes())
  342. br := bytes.NewReader(pb)
  343. _ = binary.Read(br, binary.LittleEndian, &x)
  344. i = uint64(x)
  345. } else if l <= 8 {
  346. var x uint64
  347. pb := c.padBytes(8, b.Bytes())
  348. br := bytes.NewReader(pb)
  349. _ = binary.Read(br, binary.LittleEndian, &x)
  350. i = x
  351. }
  352. return i
  353. }
  354. func (c *Conn) padBytes(l int, b []byte) []byte {
  355. bl := len(b)
  356. pl := l - bl
  357. for i := 0; i < pl; i++ {
  358. b = append(b, NullByte)
  359. }
  360. return b
  361. }
  362. func (c *Conn) encFixedLenInt(v uint64, l uint64) []byte {
  363. b := make([]byte, 8)
  364. binary.LittleEndian.PutUint64(b, v)
  365. return b[:l]
  366. }
  367. func (c *Conn) encLenEncInt(v uint64) []byte {
  368. prefix := make([]byte, 1)
  369. var b []byte
  370. switch {
  371. case v < MaxUint08:
  372. b = make([]byte, 2)
  373. binary.LittleEndian.PutUint16(b, uint16(v))
  374. b = b[:1]
  375. case v >= MaxUint08 && v < MaxUint16:
  376. prefix[0] = 0xFC
  377. b = make([]byte, 3)
  378. binary.LittleEndian.PutUint16(b, uint16(v))
  379. b = b[:2]
  380. case v >= MaxUint16 && v < MaxUint24:
  381. prefix[0] = 0xFD
  382. b = make([]byte, 4)
  383. binary.LittleEndian.PutUint32(b, uint32(v))
  384. b = b[:3]
  385. case v >= MaxUint24 && v < MaxUint64:
  386. prefix[0] = 0xFE
  387. b = make([]byte, 9)
  388. binary.LittleEndian.PutUint64(b, uint64(v))
  389. }
  390. if len(b) > 1 {
  391. b = append(prefix, b...)
  392. }
  393. return b
  394. }
  395. func (c *Conn) bitmaskToStruct(b []byte, s interface{}) interface{} {
  396. l := len(b)
  397. t := reflect.TypeOf(s)
  398. v := reflect.New(t.Elem()).Elem()
  399. for i := uint(0); i < uint(v.NumField()); i++ {
  400. f := v.Field(int(i))
  401. var v bool
  402. switch {
  403. case l > 4:
  404. x := binary.LittleEndian.Uint64(b)
  405. flag := uint64(1 << i)
  406. v = x&flag > 0
  407. case l > 2:
  408. x := binary.LittleEndian.Uint32(b)
  409. flag := uint32(1 << i)
  410. v = x&flag > 0
  411. case l > 1:
  412. x := binary.LittleEndian.Uint16(b)
  413. flag := uint16(1 << i)
  414. v = x&flag > 0
  415. default:
  416. x := uint(b[0])
  417. flag := uint(1 << i)
  418. v = x&flag > 0
  419. }
  420. f.SetBool(v)
  421. }
  422. return v.Interface()
  423. }
  424. func (c *Conn) structToBitmask(s interface{}) []byte {
  425. t := reflect.TypeOf(s).Elem()
  426. sV := reflect.ValueOf(s).Elem()
  427. fC := uint(t.NumField())
  428. m := uint64(0)
  429. for i := uint(0); i < fC; i++ {
  430. f := sV.Field(int(i))
  431. v := f.Bool()
  432. if v {
  433. m |= 1 << i
  434. }
  435. }
  436. l := uint64(math.Ceil(float64(fC) / 8.0))
  437. b := make([]byte, 8)
  438. binary.LittleEndian.PutUint64(b, m)
  439. switch {
  440. case l > 4: // 64 bits
  441. b = b[:8]
  442. case l > 2: // 32 bits
  443. b = b[:4]
  444. case l > 1: // 16 bits
  445. b = b[:2]
  446. default: // 8 bits
  447. b = b[:1]
  448. }
  449. return b
  450. }
  451. func (c *Conn) putString(t int, v string) uint64 {
  452. b := make([]byte, 0)
  453. switch t {
  454. case TypeFixedString:
  455. b = c.encFixedString(v)
  456. case TypeLenEncString:
  457. b = c.encLenEncString(v)
  458. case TypeNullTerminatedString:
  459. b = c.encNullTerminatedString(v)
  460. case TypeRestOfPacketString:
  461. b = c.encRestOfPacketString(v)
  462. }
  463. l, err := c.writeBuf.Write(b)
  464. if err != nil {
  465. c.err = err
  466. }
  467. return uint64(l)
  468. }
  469. func (c *Conn) encLenEncString(v string) []byte {
  470. l := uint64(len(v))
  471. b := c.encLenEncInt(l)
  472. return append(b, c.encFixedString(v)...)
  473. }
  474. func (c *Conn) encNullTerminatedString(v string) []byte {
  475. return append([]byte(v), NullByte)
  476. }
  477. func (c *Conn) encFixedString(v string) []byte {
  478. return []byte(v)
  479. }
  480. func (c *Conn) encRestOfPacketString(v string) []byte {
  481. s := c.encFixedString(v)
  482. return s
  483. }
  484. func (c *Conn) putInt(t int, v uint64, l uint64) uint64 {
  485. c.setupWriteBuffer()
  486. b := make([]byte, 0)
  487. switch t {
  488. case TypeFixedInt:
  489. b = c.encFixedLenInt(v, l)
  490. case TypeLenEncInt:
  491. b = c.encLenEncInt(v)
  492. }
  493. n, err := c.writeBuf.Write(b)
  494. if err != nil {
  495. c.err = err
  496. }
  497. return uint64(n)
  498. }
  499. func (c *Conn) putNullBytes(n uint64) uint64 {
  500. c.setupWriteBuffer()
  501. b := make([]byte, n)
  502. l, err := c.writeBuf.Write(b)
  503. if err != nil {
  504. c.err = err
  505. }
  506. return uint64(l)
  507. }
  508. func (c *Conn) putBytes(v []byte) uint64 {
  509. c.setupWriteBuffer()
  510. l, err := c.writeBuf.Write(v)
  511. if err != nil {
  512. c.err = err
  513. }
  514. return uint64(l)
  515. }
  516. func (c *Conn) Flush() error {
  517. if c.err != nil {
  518. return c.err
  519. }
  520. c.writeBuf = c.addHeader()
  521. //fmt.Printf("c.writeBuf = %x\n", c.writeBuf.Bytes())
  522. _, _ = c.buffer.Write(c.writeBuf.Bytes())
  523. if c.buffer.Flush() != nil {
  524. return c.buffer.Flush()
  525. }
  526. c.writeBuf = nil
  527. return nil
  528. }
  529. func (c *Conn) addHeader() *bytes.Buffer {
  530. pl := uint64(c.writeBuf.Len())
  531. sId := uint64(c.sequenceId)
  532. c.sequenceId++
  533. plB := c.encFixedLenInt(pl, 3)
  534. sIdB := c.encFixedLenInt(sId, 1)
  535. return bytes.NewBuffer(append(append(plB, sIdB...), c.writeBuf.Bytes()...))
  536. }
  537. func (c *Conn) setupWriteBuffer() {
  538. if c.writeBuf == nil {
  539. c.writeBuf = bytes.NewBuffer(nil)
  540. }
  541. }
  542. type StatusFlags struct {
  543. }
  544. type OKPacket struct {
  545. *PacketHeader
  546. Header uint64
  547. AffectedRows uint64
  548. LastInsertID uint64
  549. StatusFlags uint64
  550. Warnings uint64
  551. Info string
  552. SessionStateInfo string
  553. }
  554. func (c *Conn) decodeOKPacket(ph *PacketHeader) (*OKPacket, error) {
  555. op := OKPacket{}
  556. op.PacketHeader = ph
  557. op.Header = ph.Status
  558. op.AffectedRows = c.getInt(TypeLenEncInt, 0)
  559. op.LastInsertID = c.getInt(TypeLenEncInt, 0)
  560. if c.HandshakeResponse.ClientFlag.Protocol41 {
  561. op.StatusFlags = c.getInt(TypeFixedInt, 2)
  562. op.Warnings = c.getInt(TypeFixedInt, 1)
  563. } else if c.HandshakeResponse.ClientFlag.Transactions {
  564. op.StatusFlags = c.getInt(TypeFixedInt, 2)
  565. }
  566. if c.HandshakeResponse.ClientFlag.SessionTrack {
  567. op.Info = c.getString(TypeRestOfPacketString, 0)
  568. } else {
  569. op.Info = c.getString(TypeRestOfPacketString, 0)
  570. }
  571. return &op, nil
  572. }
  573. type ErrorPacket struct {
  574. *PacketHeader
  575. ErrorCode uint64
  576. ErrorMessage string
  577. SQLStateMarker string
  578. SQLState string
  579. }
  580. func (c *Conn) decodeErrorPacket(ph *PacketHeader) (*ErrorPacket, error) {
  581. ep := ErrorPacket{}
  582. ep.PacketHeader = ph
  583. ep.ErrorCode = c.getInt(TypeFixedInt, 2)
  584. ep.SQLStateMarker = c.getString(TypeFixedString, 1)
  585. ep.SQLState = c.getString(TypeFixedString, 5)
  586. ep.ErrorMessage = c.getString(TypeRestOfPacketString, 0)
  587. err := c.scanner.Err()
  588. if err != nil {
  589. return nil, err
  590. }
  591. return &ep, nil
  592. }
  593. func (c *Conn) setConnection(nc net.Conn) {
  594. c.curConn = nc
  595. c.buffer = bufio.NewReadWriter(
  596. bufio.NewReader(c.curConn),
  597. bufio.NewWriter(c.curConn),
  598. )
  599. c.scanner = bufio.NewScanner(c.buffer.Reader)
  600. c.scanner.Split(bufio.ScanBytes)
  601. }