connection.go 15 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755
  1. package binlog
  2. import (
  3. "bufio"
  4. "bytes"
  5. "crypto/tls"
  6. "database/sql"
  7. "database/sql/driver"
  8. "encoding/binary"
  9. "encoding/json"
  10. "fmt"
  11. "io/ioutil"
  12. "math"
  13. "net"
  14. "reflect"
  15. "strings"
  16. "time"
  17. )
  // NullByte is the zero byte the MySQL protocol uses to terminate
// null-terminated strings.
const NullByte byte = 0

// MaxPacketSize is the maximum packet size assumed by this package.
//
// NOTE(review): the MySQL packet header carries a 3-byte payload length, which
// allows payloads up to MaxUint24 bytes; this constant is only MaxUint16.
// Confirm the intended limit before relying on it.
const MaxPacketSize = MaxUint16

// Wire encodings selectable in getInt/getString and putInt/putString.
const (
	// TypeNullTerminatedString is a string terminated by NullByte.
	TypeNullTerminatedString int = iota
	// TypeFixedString is a string of a caller-supplied byte length.
	TypeFixedString
	// TypeFixedInt is a little-endian integer of a caller-supplied byte length.
	TypeFixedInt
	// TypeLenEncInt is a MySQL length-encoded integer.
	TypeLenEncInt
	// TypeRestOfPacketString is a string running to the end of the current packet.
	TypeRestOfPacketString
	// TypeLenEncString is a length-encoded integer n followed by n bytes.
	TypeLenEncString
)

// Largest values representable in unsigned integers of the given bit width.
const (
	// MaxUint08 is the largest unsigned 8-bit integer (255).
	MaxUint08 = 1<<8 - 1
	// MaxUint16 is the largest unsigned 16-bit integer (65535).
	MaxUint16 = 1<<16 - 1
	// MaxUint24 is the largest unsigned 24-bit integer (16777215).
	MaxUint24 = 1<<24 - 1
	// MaxUint64 is the largest unsigned 64-bit integer.
	MaxUint64 = 1<<64 - 1
)

// First payload byte values that readPacket uses to tell server response
// packets apart.
const (
	// StatusOK marks an OK packet.
	StatusOK = 0x00
	// StatusAuth marks an auth-more-data packet sent during authentication.
	StatusAuth = 0x01
	// StatusEOF marks an EOF packet.
	StatusEOF = 0xFE
	// StatusErr marks an ERR packet.
	StatusErr = 0xFF
)
  50. // Config represents the required parameters required to make a MySQL connection.
  51. type Config struct {
  52. Host string `json:"host"`
  53. Port int `json:"port"`
  54. User string `json:"user"`
  55. Pass string `json:"password"`
  56. Database string `json:"database"`
  57. SSL bool `json:"ssl"`
  58. SSLCA string `json:"ssl-ca"`
  59. SSLCer string `json:"ssl-cer"`
  60. SSLKey string `json:"ssl-key"`
  61. VerifyCert bool `json:"verify-cert"`
  62. ServerID uint64 `json:"server-id"`
  63. BinlogFile string `json:"binlog-file"`
  64. Timeout time.Duration
  65. }
  // newBinlogConfig loads connection settings from the JSON file at path dsn.
// Despite its name, dsn is a filesystem path, not a MySQL DSN string; the
// file's keys are mapped onto Config through its json struct tags.
//
// On a read or decode error it returns a nil *Config with the error, so callers
// never see a partially populated configuration.
func newBinlogConfig(dsn string) (*Config, error) {
	raw, err := ioutil.ReadFile(dsn)
	if err != nil {
		return nil, err
	}
	var config Config
	if err := json.Unmarshal(raw, &config); err != nil {
		return nil, err
	}
	return &config, nil
}
  76. // Conn represents a connection to a MySQL server.
  77. type Conn struct {
  78. Config *Config
  79. curConn net.Conn
  80. tcpConn *net.TCPConn
  81. secTCPConn *tls.Conn
  82. Handshake *Handshake
  83. HandshakeResponse *HandshakeResponse
  84. buffer *bufio.ReadWriter
  85. scanner *bufio.Scanner
  86. err error
  87. sequenceID uint64
  88. writeBuf *bytes.Buffer
  89. StatusFlags *StatusFlags
  90. Listener *net.Listener
  91. packetHeader *PacketHeader
  92. scanPos uint64
  93. }
  94. func newBinlogConn(config *Config) Conn {
  95. return Conn{
  96. Config: config,
  97. sequenceID: 1,
  98. }
  99. }
  // Prepare implements driver.Conn. Prepared statements are not supported on a
// binlog replication connection, so it always returns an error. Returning a
// nil Stmt together with a nil error would make database/sql use a nil Stmt.
func (c Conn) Prepare(query string) (driver.Stmt, error) {
	return nil, fmt.Errorf("mysql-binlog: Prepare is not supported")
}
  104. // Close is not yet implemented.s
  105. func (c Conn) Close() error {
  106. return nil
  107. }
  // Begin implements driver.Conn. Transactions are not supported on a binlog
// replication connection, so it always returns an error rather than a nil Tx
// with a nil error.
func (c Conn) Begin() (driver.Tx, error) {
	return nil, fmt.Errorf("mysql-binlog: transactions are not supported")
}
  // Driver is the database/sql driver registered under "mysql-binlog" by init.
// It carries no state; Driver.Open does all the work.
type Driver struct{}
  114. // Open creates the connection to the MySQL server.
  115. func (d Driver) Open(dsn string) (driver.Conn, error) {
  116. config, err := newBinlogConfig(dsn)
  117. if nil != err {
  118. return nil, err
  119. }
  120. c := newBinlogConn(config)
  121. var t interface{}
  122. dialer := net.Dialer{Timeout: c.Config.Timeout}
  123. addr := fmt.Sprintf("%s:%d", c.Config.Host, c.Config.Port)
  124. t, err = dialer.Dial("tcp", addr)
  125. if err != nil {
  126. netErr, ok := err.(net.Error)
  127. if ok && !netErr.Temporary() {
  128. fmt.Printf("Error: %s", netErr.Error())
  129. return nil, err
  130. }
  131. } else {
  132. c.tcpConn = t.(*net.TCPConn)
  133. c.setConnection(t.(net.Conn))
  134. }
  135. err = c.decodeHandshakePacket()
  136. if err != nil {
  137. return nil, err
  138. }
  139. c.HandshakeResponse = c.NewHandshakeResponse()
  140. // If we are on SSL send SSL_Request packet now
  141. if c.Config.SSL {
  142. err = c.writeSSLRequestPacket()
  143. if err != nil {
  144. return nil, err
  145. }
  146. tlsConf := NewClientTLSConfig(
  147. c.Config.SSLKey,
  148. c.Config.SSLCer,
  149. []byte(c.Config.SSLCA),
  150. c.Config.VerifyCert,
  151. c.Config.Host,
  152. )
  153. c.secTCPConn = tls.Client(c.tcpConn, tlsConf)
  154. c.setConnection(c.secTCPConn)
  155. }
  156. err = c.writeHandshakeResponse()
  157. if err != nil {
  158. return nil, err
  159. }
  160. // Listen for auth response.
  161. _, err = c.readPacket()
  162. if err != nil {
  163. return nil, err
  164. }
  165. // Auth was successful.
  166. c.sequenceID = 0
  167. // Register as a slave
  168. err = c.registerAsSlave()
  169. if err != nil {
  170. return nil, err
  171. }
  172. c.sequenceID = 0
  173. _, err = c.readPacket()
  174. if err != nil {
  175. return nil, err
  176. }
  177. err = c.startBinlogStream()
  178. if err != nil {
  179. return nil, err
  180. }
  181. err = c.listenForBinlog()
  182. if err != nil {
  183. return nil, err
  184. }
  185. return c, err
  186. }
  // readPacket reads one server packet and decodes it according to its status
// byte. The decoded value is:
//   - the auth-more-data packet for StatusAuth,
//   - an *OKPacket for StatusOK and StatusEOF,
//   - an *ErrorPacket for StatusErr, returned together with a non-nil error,
//   - nil for an unknown status.
//
// For an auth-more-data packet asking for full authentication, it also writes
// the password reply and flushes it before returning.
func (c *Conn) readPacket() (interface{}, error) {
	ph, err := c.getPacketHeader()
	if err != nil {
		return nil, err
	}
	var res interface{}
	switch ph.Status {
	case StatusAuth:
		authRes, err := c.decodeAuthMoreDataResponsePacket(ph)
		if err != nil {
			return nil, err
		}
		// Assign to the outer res; a ":=" here would shadow it and the decoded
		// packet would never reach the caller.
		res = authRes
		switch authRes.Data {
		case Sha2FastAuthSuccess:
		case Sha2RequestPublicKey:
		case Sha2PerformFullAuthentication:
			// NOTE(review): this sends the password as cleartext whether or not
			// the connection uses TLS. caching_sha2_password expects TLS or RSA
			// encryption at this point; confirm before using without SSL.
			c.putBytes(append([]byte(c.Config.Pass), NullByte))
			// Flush exactly once: a second call after a failure would frame
			// the buffered packet again.
			if err := c.Flush(); err != nil {
				return nil, err
			}
		}
	case StatusEOF, StatusOK:
		okPkt, err := c.decodeOKPacket(ph)
		if err != nil {
			return nil, err
		}
		res = okPkt
	case StatusErr:
		errPkt, err := c.decodeErrorPacket(ph)
		if err != nil {
			return nil, err
		}
		return errPkt, fmt.Errorf(
			"error %d: %s",
			errPkt.ErrorCode,
			errPkt.ErrorMessage,
		)
	default:
		fmt.Printf("Unknown PacketHeader: %+v\n", ph)
	}
	if err := c.scanner.Err(); err != nil {
		return nil, err
	}
	return res, nil
}
  // PacketHeader is the decoded start of a MySQL protocol packet, as produced
// by getPacketHeader.
type PacketHeader struct {
	Length     uint64 // payload length in bytes (3-byte field); counts the Status byte
	SequenceID uint64 // 1-byte packet sequence number
	Status     uint64 // first payload byte; readPacket switches on it
}
  241. func (c *Conn) getPacketHeader() (*PacketHeader, error) {
  242. ph := PacketHeader{}
  243. ph.Length = c.getInt(TypeFixedInt, 3)
  244. ph.SequenceID = c.getInt(TypeFixedInt, 1)
  245. ph.Status = c.getInt(TypeFixedInt, 1)
  246. fmt.Printf("ph.Status = %+v\n", ph.Status)
  247. err := c.scanner.Err()
  248. if err != nil {
  249. return &ph, err
  250. }
  251. c.packetHeader = &ph
  252. c.scanPos = 0
  253. return &ph, nil
  254. }
  // init registers Driver with database/sql under the name "mysql-binlog", so
// sql.Open("mysql-binlog", path) reaches Driver.Open with path as its dsn.
// sql.Register panics if that name is already registered.
func init() {
	drv := new(Driver)
	sql.Register("mysql-binlog", drv)
}
  // readBytes reads up to l bytes from the connection, one scanner token at a
// time, and advances scanPos by the number of bytes actually read.
//
// If the scanner stops early (EOF or a read error), the returned buffer is
// shorter than l. Read errors do not panic: they stay available through
// c.scanner.Err(), which getPacketHeader, readPacket and decodeErrorPacket
// check after decoding.
func (c *Conn) readBytes(l uint64) *bytes.Buffer {
	b := make([]byte, 0)
	for i := uint64(0); i < l; i++ {
		// Once Scan reports false it keeps doing so; stop instead of spinning
		// through the remaining iterations.
		if !c.scanner.Scan() {
			break
		}
		b = append(b, c.scanner.Bytes()...)
	}
	c.scanPos += uint64(len(b))
	return bytes.NewBuffer(b)
}
  // getBytesUntilNull reads one byte at a time until it reads a NullByte, which
// is included in the result, or until a read returns no byte.
func (c *Conn) getBytesUntilNull() *bytes.Buffer {
	collected := []byte{}
	for {
		chunk := c.readBytes(1).Bytes()
		collected = append(collected, chunk...)
		if len(chunk) != 1 || chunk[0] == NullByte {
			return bytes.NewBuffer(collected)
		}
	}
}
  287. func (c *Conn) discardBytes(l uint64) {
  288. c.readBytes(l)
  289. }
  290. func (c *Conn) getInt(t int, l uint64) uint64 {
  291. var v uint64
  292. switch t {
  293. case TypeFixedInt:
  294. v = c.decFixedInt(l)
  295. case TypeLenEncInt:
  296. v = c.decLenEncInt()
  297. default:
  298. v = 0
  299. }
  300. return v
  301. }
  302. func (c *Conn) getString(t int, l uint64) string {
  303. var v string
  304. switch t {
  305. case TypeFixedString:
  306. v = c.decFixedString(l)
  307. case TypeLenEncString:
  308. v = string(c.decLenEncInt())
  309. case TypeNullTerminatedString:
  310. v = c.decNullTerminatedString()
  311. case TypeRestOfPacketString:
  312. v = c.decRestOfPacketString()
  313. default:
  314. v = ""
  315. }
  316. return v
  317. }
  318. func (c *Conn) decRestOfPacketString() string {
  319. b := c.getRemainingBytes()
  320. return b.String()
  321. }
  322. func (c *Conn) getRemainingBytes() *bytes.Buffer {
  323. l := (c.packetHeader.Length - 1) - c.scanPos
  324. b := c.readBytes(l)
  325. return b
  326. }
  // decNullTerminatedString reads bytes up to and including a NullByte and
// returns them as a string with trailing NullBytes removed.
func (c *Conn) decNullTerminatedString() string {
	raw := c.getBytesUntilNull().String()
	return strings.TrimRight(raw, string(NullByte))
}
  331. func (c *Conn) decFixedString(l uint64) string {
  332. b := c.readBytes(l)
  333. return b.String()
  334. }
  335. func (c *Conn) decLenEncInt() uint64 {
  336. var l uint16
  337. b := c.readBytes(1)
  338. br := bytes.NewReader(b.Bytes())
  339. _ = binary.Read(br, binary.LittleEndian, &l)
  340. if l > 0 {
  341. return c.decFixedInt(uint64(l))
  342. }
  343. return 0
  344. }
  // decFixedInt reads l bytes and decodes them as an unsigned little-endian
// integer. Widths above 8 bytes are still consumed but decode as 0. A short
// read is treated as if the missing high-order bytes were zero.
func (c *Conn) decFixedInt(l uint64) uint64 {
	raw := c.readBytes(l).Bytes()
	if l > 8 {
		return 0
	}
	var word [8]byte
	copy(word[:], raw)
	return binary.LittleEndian.Uint64(word[:])
}
  369. func (c *Conn) padBytes(l int, b []byte) []byte {
  370. bl := len(b)
  371. pl := l - bl
  372. for i := 0; i < pl; i++ {
  373. b = append(b, NullByte)
  374. }
  375. return b
  376. }
  // encFixedLenInt encodes v as an l-byte little-endian integer. High-order
// bytes are truncated when l < 8; the value is zero-padded when l > 8, where
// slicing a fixed 8-byte scratch buffer would panic.
func (c *Conn) encFixedLenInt(v uint64, l uint64) []byte {
	size := l
	if size < 8 {
		size = 8
	}
	b := make([]byte, size)
	binary.LittleEndian.PutUint64(b, v)
	return b[:l]
}
  382. func (c *Conn) encLenEncInt(v uint64) []byte {
  383. prefix := make([]byte, 1)
  384. var b []byte
  385. switch {
  386. case v < MaxUint08:
  387. b = make([]byte, 2)
  388. binary.LittleEndian.PutUint16(b, uint16(v))
  389. b = b[:1]
  390. case v >= MaxUint08 && v < MaxUint16:
  391. prefix[0] = 0xFC
  392. b = make([]byte, 3)
  393. binary.LittleEndian.PutUint16(b, uint16(v))
  394. b = b[:2]
  395. case v >= MaxUint16 && v < MaxUint24:
  396. prefix[0] = 0xFD
  397. b = make([]byte, 4)
  398. binary.LittleEndian.PutUint32(b, uint32(v))
  399. b = b[:3]
  400. case v >= MaxUint24 && v < MaxUint64:
  401. prefix[0] = 0xFE
  402. b = make([]byte, 9)
  403. binary.LittleEndian.PutUint64(b, v)
  404. }
  405. if len(b) > 1 {
  406. b = append(prefix, b...)
  407. }
  408. return b
  409. }
  // bitmaskToStruct is the inverse of structToBitmask. s must be a pointer to
// a struct of bool fields. The result is a new struct value (not a pointer)
// whose field i is set from bit i of the little-endian bitmask b.
//
// b may have any length: it is widened to 8 bytes and bytes past the eighth
// are ignored. Lengths such as 0, 3, 5, 6 and 7 therefore decode instead of
// panicking. Fields past the 64th are always false. A non-bool field panics
// in reflect.Value.SetBool.
func (c *Conn) bitmaskToStruct(b []byte, s interface{}) interface{} {
	var word [8]byte
	copy(word[:], b)
	mask := binary.LittleEndian.Uint64(word[:])

	out := reflect.New(reflect.TypeOf(s).Elem()).Elem()
	for i := 0; i < out.NumField(); i++ {
		out.Field(i).SetBool((mask>>uint(i))&1 == 1)
	}
	return out.Interface()
}
  439. func (c *Conn) structToBitmask(s interface{}) []byte {
  440. t := reflect.TypeOf(s).Elem()
  441. sV := reflect.ValueOf(s).Elem()
  442. fC := uint(t.NumField())
  443. m := uint64(0)
  444. for i := uint(0); i < fC; i++ {
  445. f := sV.Field(int(i))
  446. v := f.Bool()
  447. if v {
  448. m |= 1 << i
  449. }
  450. }
  451. l := uint64(math.Ceil(float64(fC) / 8.0))
  452. b := make([]byte, 8)
  453. binary.LittleEndian.PutUint64(b, m)
  454. switch {
  455. case l > 4: // 64 bits
  456. b = b[:8]
  457. case l > 2: // 32 bits
  458. b = b[:4]
  459. case l > 1: // 16 bits
  460. b = b[:2]
  461. default: // 8 bits
  462. b = b[:1]
  463. }
  464. return b
  465. }
  466. func (c *Conn) putString(t int, v string) uint64 {
  467. b := make([]byte, 0)
  468. switch t {
  469. case TypeFixedString:
  470. b = c.encFixedString(v)
  471. case TypeLenEncString:
  472. b = c.encLenEncString(v)
  473. case TypeNullTerminatedString:
  474. b = c.encNullTerminatedString(v)
  475. case TypeRestOfPacketString:
  476. b = c.encRestOfPacketString(v)
  477. }
  478. l, err := c.writeBuf.Write(b)
  479. if err != nil {
  480. c.err = err
  481. }
  482. return uint64(l)
  483. }
  484. func (c *Conn) encLenEncString(v string) []byte {
  485. l := uint64(len(v))
  486. b := c.encLenEncInt(l)
  487. return append(b, c.encFixedString(v)...)
  488. }
  489. func (c *Conn) encNullTerminatedString(v string) []byte {
  490. return append([]byte(v), NullByte)
  491. }
  492. func (c *Conn) encFixedString(v string) []byte {
  493. return []byte(v)
  494. }
  495. func (c *Conn) encRestOfPacketString(v string) []byte {
  496. s := c.encFixedString(v)
  497. return s
  498. }
  499. func (c *Conn) putInt(t int, v uint64, l uint64) uint64 {
  500. c.setupWriteBuffer()
  501. b := make([]byte, 0)
  502. switch t {
  503. case TypeFixedInt:
  504. b = c.encFixedLenInt(v, l)
  505. case TypeLenEncInt:
  506. b = c.encLenEncInt(v)
  507. }
  508. n, err := c.writeBuf.Write(b)
  509. if err != nil {
  510. c.err = err
  511. }
  512. return uint64(n)
  513. }
  514. func (c *Conn) putNullBytes(n uint64) uint64 {
  515. c.setupWriteBuffer()
  516. b := make([]byte, n)
  517. l, err := c.writeBuf.Write(b)
  518. if err != nil {
  519. c.err = err
  520. }
  521. return uint64(l)
  522. }
  523. func (c *Conn) putBytes(v []byte) uint64 {
  524. c.setupWriteBuffer()
  525. l, err := c.writeBuf.Write(v)
  526. if err != nil {
  527. c.err = err
  528. }
  529. return uint64(l)
  530. }
  531. // Flush clears the write buffer.
  532. func (c *Conn) Flush() error {
  533. if c.err != nil {
  534. return c.err
  535. }
  536. c.writeBuf = c.addHeader()
  537. //fmt.Printf("c.writeBuf = %x\n", c.writeBuf.Bytes())
  538. _, _ = c.buffer.Write(c.writeBuf.Bytes())
  539. if c.buffer.Flush() != nil {
  540. return c.buffer.Flush()
  541. }
  542. c.writeBuf = nil
  543. return nil
  544. }
  545. func (c *Conn) addHeader() *bytes.Buffer {
  546. pl := uint64(c.writeBuf.Len())
  547. sID := c.sequenceID
  548. c.sequenceID++
  549. var plB = c.encFixedLenInt(pl, 3)
  550. var sIDB = c.encFixedLenInt(sID, 1)
  551. return bytes.NewBuffer(append(append(plB, sIDB...), c.writeBuf.Bytes()...))
  552. }
  // setupWriteBuffer allocates an empty packet buffer if none is pending; an
// existing buffer is left untouched.
func (c *Conn) setupWriteBuffer() {
	if c.writeBuf != nil {
		return
	}
	c.writeBuf = new(bytes.Buffer)
}
  // StatusFlags is an empty placeholder type. Conn has a *StatusFlags field,
// but nothing in this file populates it.
type StatusFlags struct{}
  561. // OKPacket represents an OK packet in the MySQL protocol.
  562. type OKPacket struct {
  563. *PacketHeader
  564. Header uint64
  565. AffectedRows uint64
  566. LastInsertID uint64
  567. StatusFlags uint64
  568. Warnings uint64
  569. Info string
  570. SessionStateInfo string
  571. }
  572. func (c *Conn) decodeOKPacket(ph *PacketHeader) (*OKPacket, error) {
  573. op := OKPacket{}
  574. op.PacketHeader = ph
  575. op.Header = ph.Status
  576. op.AffectedRows = c.getInt(TypeLenEncInt, 0)
  577. op.LastInsertID = c.getInt(TypeLenEncInt, 0)
  578. if c.HandshakeResponse.ClientFlag.Protocol41 {
  579. op.StatusFlags = c.getInt(TypeFixedInt, 2)
  580. op.Warnings = c.getInt(TypeFixedInt, 1)
  581. } else if c.HandshakeResponse.ClientFlag.Transactions {
  582. op.StatusFlags = c.getInt(TypeFixedInt, 2)
  583. }
  584. if c.HandshakeResponse.ClientFlag.SessionTrack {
  585. op.Info = c.getString(TypeRestOfPacketString, 0)
  586. } else {
  587. op.Info = c.getString(TypeRestOfPacketString, 0)
  588. }
  589. return &op, nil
  590. }
  591. // ErrorPacket represents an error packet in the MySQL protocol.
  592. type ErrorPacket struct {
  593. *PacketHeader
  594. ErrorCode uint64
  595. ErrorMessage string
  596. SQLStateMarker string
  597. SQLState string
  598. }
  599. func (c *Conn) decodeErrorPacket(ph *PacketHeader) (*ErrorPacket, error) {
  600. ep := ErrorPacket{}
  601. ep.PacketHeader = ph
  602. ep.ErrorCode = c.getInt(TypeFixedInt, 2)
  603. ep.SQLStateMarker = c.getString(TypeFixedString, 1)
  604. ep.SQLState = c.getString(TypeFixedString, 5)
  605. ep.ErrorMessage = c.getString(TypeRestOfPacketString, 0)
  606. err := c.scanner.Err()
  607. if err != nil {
  608. return nil, err
  609. }
  610. return &ep, nil
  611. }
  // setConnection makes nc the connection used for all further I/O. It wraps
// nc in a fresh buffered reader/writer and a scanner that yields one byte per
// token.
//
// NOTE(review): bytes already buffered by a previous reader are dropped when
// this is called again (as in the TLS upgrade in Open).
func (c *Conn) setConnection(nc net.Conn) {
	reader := bufio.NewReader(nc)
	writer := bufio.NewWriter(nc)
	c.curConn = nc
	c.buffer = bufio.NewReadWriter(reader, writer)
	c.scanner = bufio.NewScanner(reader)
	c.scanner.Split(bufio.ScanBytes)
}