connection.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696
  1. package binlog
  2. import (
  3. "bufio"
  4. "bytes"
  5. "crypto/tls"
  6. "database/sql"
  7. "database/sql/driver"
  8. "encoding/binary"
  9. "encoding/json"
  10. "errors"
  11. "fmt"
  12. "io/ioutil"
  13. "math"
  14. "net"
  15. "reflect"
  16. "strings"
  17. "time"
  18. )
  // Misc. constants.
const (
	// NullByte terminates null-terminated strings (see encNullTerminatedString).
	NullByte byte = 0
	// MaxPacketSize is MaxUint16 (65535).
	MaxPacketSize = MaxUint16
)

// MySQL packet data types. They select the encoding used by getInt/getString
// and putInt/putString.
const (
	TypeNullTerminatedString int = iota // 0
	TypeFixedString                     // 1
	TypeFixedInt                        // 2
	TypeLenEncInt                       // 3
	TypeRestOfPacketString              // 4
	TypeLenEncString                    // 5
)

// Integer maximums: the largest unsigned values that fit in 8, 16, 24 and 64 bits.
const (
	MaxUint08 = 0xFF
	MaxUint16 = 0xFFFF
	MaxUint24 = 0xFFFFFF
	MaxUint64 = 0xFFFFFFFFFFFFFFFF
)

// Packet statuses: the first payload byte of a server packet, which
// getPacketHeader stores in PacketHeader.Status and listen switches on.
const (
	StatusOK   = 0x00
	StatusEOF  = 0xFE
	StatusErr  = 0xFF
	StatusAuth = 0x01
)
  39. type Config struct {
  40. Host string `json:"host"`
  41. Port int `json:"port"`
  42. User string `json:"user"`
  43. Pass string `json:"password"`
  44. Database string `json:"database"`
  45. SSL bool `json:"ssl"`
  46. SSLCA string `json:"ssl-ca"`
  47. SSLCer string `json:"ssl-cer"`
  48. SSLKey string `json:"ssl-key"`
  49. VerifyCert bool `json:"verify-cert"`
  50. Timeout time.Duration
  51. }
  52. func newBinlogConfig(dsn string) (*Config, error) {
  53. var err error
  54. b, err := ioutil.ReadFile(dsn)
  55. if err != nil {
  56. return nil, err
  57. }
  58. config := Config{}
  59. err = json.Unmarshal(b, &config)
  60. return &config, err
  61. }
  62. type Conn struct {
  63. Config *Config
  64. curConn net.Conn
  65. tcpConn *net.TCPConn
  66. secTCPConn *tls.Conn
  67. Handshake *Handshake
  68. HandshakeResponse *HandshakeResponse
  69. buffer *bufio.ReadWriter
  70. scanner *bufio.Scanner
  71. err error
  72. sequenceId uint64
  73. writeBuf *bytes.Buffer
  74. StausFlags *StatusFlags
  75. }
  76. func newBinlogConn(config *Config) Conn {
  77. return Conn{
  78. Config: config,
  79. sequenceId: 1,
  80. }
  81. }
  // Prepare implements driver.Conn. Prepared statements are not supported on a
// binlog connection, so it always fails. database/sql expects either a usable
// driver.Stmt or a non-nil error; returning (nil, nil) would have it go on to
// use a nil statement.
func (c Conn) Prepare(query string) (driver.Stmt, error) {
	return nil, errors.New("binlog: Prepare is not supported")
}
  85. func (c Conn) Close() error {
  86. return nil
  87. }
  // Begin implements driver.Conn. Transactions are not supported on a binlog
// connection, so it always fails instead of returning a nil driver.Tx with a
// nil error, which database/sql would go on to use.
func (c Conn) Begin() (driver.Tx, error) {
	return nil, errors.New("binlog: transactions are not supported")
}
  91. type Driver struct{}
  92. func (d Driver) Open(dsn string) (driver.Conn, error) {
  93. config, err := newBinlogConfig(dsn)
  94. if nil != err {
  95. return nil, err
  96. }
  97. c := newBinlogConn(config)
  98. var t interface{}
  99. dialer := net.Dialer{Timeout: c.Config.Timeout}
  100. addr := fmt.Sprintf("%s:%d", c.Config.Host, c.Config.Port)
  101. t, err = dialer.Dial("tcp", addr)
  102. if err != nil {
  103. netErr, ok := err.(net.Error)
  104. if ok && !netErr.Temporary() {
  105. fmt.Printf("Error: %s", netErr.Error())
  106. return nil, err
  107. }
  108. } else {
  109. c.tcpConn = t.(*net.TCPConn)
  110. c.setConnection(t.(net.Conn))
  111. }
  112. err = c.decodeHandshakePacket()
  113. if err != nil {
  114. return nil, err
  115. }
  116. c.HandshakeResponse = c.NewHandshakeResponse()
  117. // If we are on SSL send SSL_Request packet now
  118. if c.Config.SSL {
  119. err = c.writeSSLRequestPacket()
  120. if err != nil {
  121. return nil, err
  122. }
  123. tlsConf := NewClientTLSConfig(
  124. c.Config.SSLKey,
  125. c.Config.SSLCer,
  126. []byte(c.Config.SSLCA),
  127. c.Config.VerifyCert,
  128. c.Config.Host,
  129. )
  130. c.secTCPConn = tls.Client(c.tcpConn, tlsConf)
  131. c.setConnection(c.secTCPConn)
  132. }
  133. err = c.writeHandshakeResponse()
  134. if err != nil {
  135. return nil, err
  136. }
  137. err = c.listen()
  138. return c, err
  139. }
  // listen reads server packets in a loop and reacts to each by its status byte:
//
//   - StatusAuth: decode an AuthMoreData packet; on a full-authentication
//     request, send the password NUL-terminated.
//   - StatusOK / StatusEOF: decode it as an OK packet.
//   - StatusErr: decode the ERR packet and return it as an error.
//
// Other statuses are skipped. listen returns only when something fails (a
// read or decode error, a server error packet, or a failed write), so its
// result is never nil. It loops rather than recursing, so the stack does not
// grow with every packet received.
func (c *Conn) listen() error {
	for {
		ph, err := c.getPacketHeader()
		if err != nil {
			return err
		}
		// The server's packet used one sequence number; our next outgoing
		// packet must carry the following one.
		c.sequenceId++

		switch ph.Status {
		case StatusAuth:
			fmt.Println("IN: AuthMoreDate PACKET")
			md, err := c.decodeAuthMoreDataResponsePacket(ph)
			if err != nil {
				return err
			}
			switch md.Data {
			case SHA2_FAST_AUTH_SUCCESS:
				fmt.Println("FAST AUTH")
			case SHA2_REQUEST_PUBLIC_KEY:
				// NOTE(review): nothing is sent back here, so the exchange
				// presumably stalls until the server gives up — confirm.
				fmt.Println("REQUEST PUBLIC KEY")
			case SHA2_PERFORM_FULL_AUTHENTICATION:
				fmt.Println("FULL AUTH")
				c.putBytes(append([]byte(c.Config.Pass), NullByte))
				// Flush exactly once: a second call would frame and send
				// the packet again.
				if err := c.Flush(); err != nil {
					return err
				}
			}
		case StatusEOF, StatusOK:
			fmt.Println("IN: OK PACKET")
			if _, err := c.decodeOKPacket(ph); err != nil {
				return err
			}
		case StatusErr:
			fmt.Println("IN: ERROR PACKET")
			ep, err := c.decodeErrorPacket(ph)
			if err != nil {
				return err
			}
			return errors.New(fmt.Sprintf("Error %d: %s", ep.ErrorCode, ep.ErrorMessage))
		}

		if err := c.scanner.Err(); err != nil {
			return err
		}
	}
}
  192. type PacketHeader struct {
  193. Length uint64
  194. SequenceID uint64
  195. Status uint64
  196. }
  197. func (c *Conn) getPacketHeader() (PacketHeader, error) {
  198. ph := PacketHeader{}
  199. ph.Length = c.getInt(TypeFixedInt, 3)
  200. ph.SequenceID = c.getInt(TypeFixedInt, 1)
  201. ph.Status = c.getInt(TypeFixedInt, 1)
  202. err := c.scanner.Err()
  203. if err != nil {
  204. return ph, err
  205. }
  206. return ph, nil
  207. }
  // init registers Driver with database/sql under the name "mysql-binlog", so
// sql.Open("mysql-binlog", pathToConfigJSON) is served by this package.
func init() {
	sql.Register("mysql-binlog", new(Driver))
}
  // readBytes consumes exactly l tokens from c.scanner (one byte each, given the
// ScanBytes split that setConnection installs) and returns them in a new
// buffer. If the scanner stops first — end of stream or a read error, see
// c.scanner.Err — it returns nil and the bytes consumed so far are dropped.
// readBytes(0) returns an empty, non-nil buffer.
func (c *Conn) readBytes(l uint64) *bytes.Buffer {
	collected := []byte{}
	for remaining := l; remaining > 0; remaining-- {
		if !c.scanner.Scan() {
			return nil
		}
		collected = append(collected, c.scanner.Bytes()...)
	}
	return bytes.NewBuffer(collected)
}
  222. func (c *Conn) getBytesUntilEOF() *bytes.Buffer {
  223. l := uint64(1)
  224. s := c.readBytes(l)
  225. b := s.Bytes()
  226. for true {
  227. if uint64(s.Len()) != l || s.Bytes()[0] == NullByte {
  228. break
  229. }
  230. s := c.readBytes(uint64(l))
  231. if s == nil {
  232. return bytes.NewBuffer(b)
  233. }
  234. b = append(b, s.Bytes()...)
  235. }
  236. return bytes.NewBuffer(b)
  237. }
  238. func (c *Conn) getBytesUntilNull() *bytes.Buffer {
  239. l := uint64(1)
  240. s := c.readBytes(l)
  241. b := s.Bytes()
  242. for true {
  243. if uint64(s.Len()) != l || s.Bytes()[0] == NullByte {
  244. break
  245. }
  246. s = c.readBytes(uint64(l))
  247. b = append(b, s.Bytes()...)
  248. }
  249. return bytes.NewBuffer(b)
  250. }
  251. func (c *Conn) discardBytes(l int) {
  252. for i := 0; i < l; i++ {
  253. c.scanner.Scan()
  254. }
  255. }
  256. func (c *Conn) getInt(t int, l uint64) uint64 {
  257. var v uint64
  258. switch t {
  259. case TypeFixedInt:
  260. v = c.decFixedInt(l)
  261. case TypeLenEncInt:
  262. v = c.decLenEncInt()
  263. default:
  264. v = 0
  265. }
  266. return v
  267. }
  268. func (c *Conn) getString(t int, l uint64) string {
  269. var v string
  270. switch t {
  271. case TypeFixedString:
  272. v = c.decFixedString(l)
  273. case TypeLenEncString:
  274. v = string(c.decLenEncInt())
  275. case TypeNullTerminatedString:
  276. v = c.decNullTerminatedString()
  277. case TypeRestOfPacketString:
  278. v = c.decRestOfPacketString()
  279. default:
  280. v = ""
  281. }
  282. return v
  283. }
  284. func (c *Conn) decRestOfPacketString() string {
  285. b := c.getBytesUntilEOF()
  286. return string(b.Bytes())
  287. }
  // decNullTerminatedString reads a string with getBytesUntilNull and strips
// every trailing NUL byte from it.
func (c *Conn) decNullTerminatedString() string {
	raw := c.getBytesUntilNull().String()
	return strings.TrimRight(raw, "\x00")
}
  292. func (c *Conn) decFixedString(l uint64) string {
  293. b := c.readBytes(l)
  294. return b.String()
  295. }
  296. func (c *Conn) decLenEncInt() uint64 {
  297. var l uint16
  298. b := c.readBytes(1)
  299. br := bytes.NewReader(b.Bytes())
  300. _ = binary.Read(br, binary.LittleEndian, &l)
  301. if l > 0 {
  302. return c.decFixedInt(uint64(l))
  303. } else {
  304. return 0
  305. }
  306. }
  // decFixedInt reads l bytes and decodes them as a little-endian unsigned
// integer. For l up to 8 the value is zero-extended to 64 bits; for l > 8 the
// bytes are still consumed but 0 is returned. A short read (nil from
// readBytes) panics when l <= 8.
func (c *Conn) decFixedInt(l uint64) uint64 {
	raw := c.readBytes(l)
	if l > 8 {
		return 0
	}
	// Zero-pad to eight bytes so a single uint64 decode covers every width;
	// the padding is in the high bytes, so 1- to 4-byte values are unchanged.
	// The read cannot fail: at least eight bytes are always available.
	var v uint64
	_ = binary.Read(bytes.NewReader(c.padBytes(8, raw.Bytes())), binary.LittleEndian, &v)
	return v
}
  331. func (c *Conn) padBytes(l int, b []byte) []byte {
  332. bl := len(b)
  333. pl := l - bl
  334. for i := 0; i < pl; i++ {
  335. b = append(b, NullByte)
  336. }
  337. return b
  338. }
  339. func (c *Conn) encFixedLenInt(v uint64, l uint64) []byte {
  340. b := make([]byte, 8)
  341. binary.LittleEndian.PutUint64(b, v)
  342. return b[:l]
  343. }
  344. func (c *Conn) encLenEncInt(v uint64) []byte {
  345. prefix := make([]byte, 1)
  346. var b []byte
  347. switch {
  348. case v < MaxUint08:
  349. b = make([]byte, 2)
  350. binary.LittleEndian.PutUint16(b, uint16(v))
  351. b = b[:1]
  352. case v >= MaxUint08 && v < MaxUint16:
  353. prefix[0] = 0xFC
  354. b = make([]byte, 3)
  355. binary.LittleEndian.PutUint16(b, uint16(v))
  356. b = b[:2]
  357. case v >= MaxUint16 && v < MaxUint24:
  358. prefix[0] = 0xFD
  359. b = make([]byte, 4)
  360. binary.LittleEndian.PutUint32(b, uint32(v))
  361. b = b[:3]
  362. case v >= MaxUint24 && v < MaxUint64:
  363. prefix[0] = 0xFE
  364. b = make([]byte, 9)
  365. binary.LittleEndian.PutUint64(b, uint64(v))
  366. }
  367. if len(b) > 1 {
  368. b = append(prefix, b...)
  369. }
  370. return b
  371. }
  // bitmaskToStruct decodes the little-endian bitmask b into a new value of the
// struct type that s points to: field i is set to whether bit i of b is set.
// Every field of that struct must be an exported bool (SetBool panics
// otherwise). The result is the struct value itself, not a pointer.
//
// b may have any length. It is zero-padded (or truncated) to 64 bits, so
// fields beyond the available bits, or beyond bit 63, come out false, and an
// empty b yields the zero struct.
func (c *Conn) bitmaskToStruct(b []byte, s interface{}) interface{} {
	var word [8]byte
	copy(word[:], b)
	mask := binary.LittleEndian.Uint64(word[:])

	out := reflect.New(reflect.TypeOf(s).Elem()).Elem()
	for i := 0; i < out.NumField(); i++ {
		// For i >= 64 the shift yields 0, so the field stays false.
		bit := uint64(1) << uint(i)
		out.Field(i).SetBool(mask&bit != 0)
	}
	return out.Interface()
}
  401. func (c *Conn) structToBitmask(s interface{}) []byte {
  402. t := reflect.TypeOf(s).Elem()
  403. sV := reflect.ValueOf(s).Elem()
  404. fC := uint(t.NumField())
  405. m := uint64(0)
  406. for i := uint(0); i < fC; i++ {
  407. f := sV.Field(int(i))
  408. v := f.Bool()
  409. if v {
  410. m |= 1 << i
  411. }
  412. }
  413. l := uint64(math.Ceil(float64(fC) / 8.0))
  414. b := make([]byte, 8)
  415. binary.LittleEndian.PutUint64(b, m)
  416. switch {
  417. case l > 4: // 64 bits
  418. b = b[:8]
  419. case l > 2: // 32 bits
  420. b = b[:4]
  421. case l > 1: // 16 bits
  422. b = b[:2]
  423. default: // 8 bits
  424. b = b[:1]
  425. }
  426. return b
  427. }
  428. func (c *Conn) putString(t int, v string) uint64 {
  429. b := make([]byte, 0)
  430. switch t {
  431. case TypeFixedString:
  432. b = c.encFixedString(v)
  433. case TypeNullTerminatedString:
  434. b = c.encNullTerminatedString(v)
  435. case TypeRestOfPacketString:
  436. b = c.encRestOfPacketString(v)
  437. }
  438. l, err := c.writeBuf.Write(b)
  439. if err != nil {
  440. c.err = err
  441. }
  442. return uint64(l)
  443. }
  444. func (c *Conn) encNullTerminatedString(v string) []byte {
  445. return append([]byte(v), NullByte)
  446. }
  447. func (c *Conn) encFixedString(v string) []byte {
  448. return []byte(v)
  449. }
  450. func (c *Conn) encRestOfPacketString(v string) []byte {
  451. s := c.encFixedString(v)
  452. return s
  453. }
  454. func (c *Conn) putInt(t int, v uint64, l uint64) uint64 {
  455. c.setupWriteBuffer()
  456. b := make([]byte, 0)
  457. switch t {
  458. case TypeFixedInt:
  459. b = c.encFixedLenInt(v, l)
  460. case TypeLenEncInt:
  461. b = c.encLenEncInt(v)
  462. }
  463. n, err := c.writeBuf.Write(b)
  464. if err != nil {
  465. c.err = err
  466. }
  467. return uint64(n)
  468. }
  469. func (c *Conn) putNullBytes(n uint64) uint64 {
  470. c.setupWriteBuffer()
  471. b := make([]byte, n)
  472. l, err := c.writeBuf.Write(b)
  473. if err != nil {
  474. c.err = err
  475. }
  476. return uint64(l)
  477. }
  478. func (c *Conn) putBytes(v []byte) uint64 {
  479. c.setupWriteBuffer()
  480. l, err := c.writeBuf.Write(v)
  481. if err != nil {
  482. c.err = err
  483. }
  484. return uint64(l)
  485. }
  486. func (c *Conn) Flush() error {
  487. if c.err != nil {
  488. return c.err
  489. }
  490. c.writeBuf = c.addHeader()
  491. _, _ = c.buffer.Write(c.writeBuf.Bytes())
  492. // log all outgoing packets
  493. // fmt.Printf(
  494. // "\nOUT:\n%08b\n%x\n%s\n\n",
  495. // c.writeBuf.Bytes(),
  496. // c.writeBuf.Bytes(),
  497. // c.writeBuf.Bytes(),
  498. // )
  499. if c.buffer.Flush() != nil {
  500. return c.buffer.Flush()
  501. }
  502. c.writeBuf = nil
  503. return nil
  504. }
  // addHeader returns a new buffer holding the staged payload (writeBuf) framed
// as a MySQL packet: a 3-byte little-endian payload length and a 1-byte
// sequence number, followed by the payload. It consumes the current
// sequenceId and advances it. Payload lengths of 2^24 or more and sequence
// numbers above 255 are truncated by the fixed-width encoding.
func (c *Conn) addHeader() *bytes.Buffer {
	payload := c.writeBuf.Bytes()
	seq := c.sequenceId
	c.sequenceId++

	frame := c.encFixedLenInt(uint64(len(payload)), 3)
	frame = append(frame, c.encFixedLenInt(seq, 1)...)
	frame = append(frame, payload...)
	return bytes.NewBuffer(frame)
}
  // setupWriteBuffer lazily creates the staging buffer for outgoing payload
// bytes; an existing buffer, and whatever it already holds, is left alone.
func (c *Conn) setupWriteBuffer() {
	if c.writeBuf != nil {
		return
	}
	c.writeBuf = new(bytes.Buffer)
}
  // StatusFlags has no fields yet; Conn.StausFlags points at one. Decoded OK
// packets keep their status flags as a plain uint64 in OKPacket.StatusFlags.
type StatusFlags struct{}
  520. type OKPacket struct {
  521. PacketHeader
  522. Header uint64
  523. AffectedRows uint64
  524. LastInsertID uint64
  525. StatusFlags uint64
  526. Warnings uint64
  527. Info string
  528. SessionStateInfo string
  529. }
  530. func (c *Conn) decodeOKPacket(ph PacketHeader) (*OKPacket, error) {
  531. op := OKPacket{}
  532. op.PacketHeader = ph
  533. op.Header = ph.Status
  534. op.AffectedRows = c.getInt(TypeLenEncInt, 0)
  535. op.LastInsertID = c.getInt(TypeLenEncInt, 0)
  536. if c.HandshakeResponse.ClientFlag.Protocol41 {
  537. op.StatusFlags = c.getInt(TypeFixedInt, 2)
  538. op.Warnings = c.getInt(TypeFixedInt, 1)
  539. } else if c.HandshakeResponse.ClientFlag.Transactions {
  540. op.StatusFlags = c.getInt(TypeFixedInt, 2)
  541. }
  542. if c.HandshakeResponse.ClientFlag.SessionTrack {
  543. op.Info = c.getString(TypeRestOfPacketString, 0)
  544. } else {
  545. op.Info = c.getString(TypeRestOfPacketString, 0)
  546. }
  547. return &op, nil
  548. }
  549. type ErrorPacket struct {
  550. PacketHeader
  551. ErrorCode uint64
  552. ErrorMessage string
  553. SQLStateMarker string
  554. SQLState string
  555. }
  556. func (c *Conn) decodeErrorPacket(ph PacketHeader) (*ErrorPacket, error) {
  557. ep := ErrorPacket{}
  558. ep.PacketHeader = ph
  559. ep.ErrorCode = c.getInt(TypeFixedInt, 2)
  560. ep.SQLStateMarker = c.getString(TypeFixedString, 1)
  561. ep.SQLState = c.getString(TypeFixedString, 5)
  562. ep.ErrorMessage = c.getString(TypeRestOfPacketString, 0)
  563. err := c.scanner.Err()
  564. if err != nil {
  565. return nil, err
  566. }
  567. return &ep, nil
  568. }
  // setConnection makes nc the connection all I/O goes through. It stores nc in
// curConn and rebuilds, on top of it, the buffered reader/writer and a scanner
// that yields one byte per Scan. Open calls it again after upgrading to TLS,
// so input still buffered but unread, and output not yet flushed, in the
// previous wrappers is dropped.
func (c *Conn) setConnection(nc net.Conn) {
	c.curConn = nc
	r := bufio.NewReader(nc)
	w := bufio.NewWriter(nc)
	c.buffer = bufio.NewReadWriter(r, w)
	c.scanner = bufio.NewScanner(r)
	c.scanner.Split(bufio.ScanBytes)
}