connection.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630
  1. package binlog
  2. import (
  3. "bufio"
  4. "bytes"
  5. "crypto/tls"
  6. "database/sql"
  7. "database/sql/driver"
  8. "encoding/binary"
  9. "encoding/json"
  10. "errors"
  11. "fmt"
  12. "io"
  13. "io/ioutil"
  14. "math"
  15. "net"
  16. "reflect"
  17. "strings"
  18. "time"
  19. )
  20. // MySQL Packet Data Types
  21. const TypeNullTerminatedString = int(0)
  22. const TypeFixedString = int(1)
  23. const TypeFixedInt = int(2)
  24. const TypeLenEncInt = int(3)
  25. const TypeRestOfPacketString = int(4)
  26. // Integer Maximums
  27. const MaxUint08 = 1<<8 - 1
  28. const MaxUint16 = 1<<16 - 1
  29. const MaxUint24 = 1<<24 - 1
  30. const MaxUint64 = 1<<64 - 1
  31. // Misc. Constants
  32. const NullByte byte = 0
  33. const MaxPacketSize = MaxUint16
  34. type Config struct {
  35. Host string `json:"host"`
  36. Port int `json:"port"`
  37. User string `json:"user"`
  38. Pass string `json:"password"`
  39. Database string `json:"database"`
  40. SSL bool `json:"ssl"`
  41. SSLCA string `json:"ssl-ca"`
  42. SSLCer string `json:"ssl-cer"`
  43. SSLKey string `json:"ssl-key"`
  44. VerifyCert bool `json:"verify-cert"`
  45. Timeout time.Duration
  46. }
  47. func splitByBytesFunc(data []byte, atEOF bool) (advance int, token []byte, err error) {
  48. if atEOF {
  49. return 0, nil, io.EOF
  50. }
  51. return 1, data[:1], nil
  52. }
  53. func newBinlogConfig(dsn string) (*Config, error) {
  54. var err error
  55. b, err := ioutil.ReadFile(dsn)
  56. if err != nil {
  57. return nil, err
  58. }
  59. config := Config{}
  60. err = json.Unmarshal(b, &config)
  61. return &config, err
  62. }
  63. type Conn struct {
  64. Config *Config
  65. curConn net.Conn
  66. tcpConn *net.TCPConn
  67. secTCPConn *tls.Conn
  68. Handshake *Handshake
  69. HandshakeResponse *HandshakeResponse
  70. buffer *bufio.ReadWriter
  71. scanner *bufio.Scanner
  72. err error
  73. sequenceId uint64
  74. writeBuf *bytes.Buffer
  75. }
  76. func newBinlogConn(config *Config) Conn {
  77. return Conn{
  78. Config: config,
  79. sequenceId: 1,
  80. }
  81. }
  82. func (c Conn) Prepare(query string) (driver.Stmt, error) {
  83. return nil, nil
  84. }
  85. func (c Conn) Close() error {
  86. return nil
  87. }
  88. func (c Conn) Begin() (driver.Tx, error) {
  89. return nil, nil
  90. }
  91. type Driver struct{}
  92. func (d Driver) Open(dsn string) (driver.Conn, error) {
  93. config, err := newBinlogConfig(dsn)
  94. if nil != err {
  95. return nil, err
  96. }
  97. c := newBinlogConn(config)
  98. var t interface{}
  99. dialer := net.Dialer{Timeout: c.Config.Timeout}
  100. addr := fmt.Sprintf("%s:%d", c.Config.Host, c.Config.Port)
  101. t, err = dialer.Dial("tcp", addr)
  102. if err != nil {
  103. netErr, ok := err.(net.Error)
  104. if ok && !netErr.Temporary() {
  105. fmt.Printf("Error: %s", netErr.Error())
  106. return nil, err
  107. }
  108. } else {
  109. c.tcpConn = t.(*net.TCPConn)
  110. c.setConnection(t.(net.Conn))
  111. }
  112. err = c.decodeHandshakePacket()
  113. if err != nil {
  114. return nil, err
  115. }
  116. c.HandshakeResponse = c.NewHandshakeResponse()
  117. // Send SSL_Request Packet
  118. if c.Config.SSL {
  119. err = c.writeSSLRequestPacket()
  120. if err != nil {
  121. return nil, err
  122. }
  123. tlsConf := NewClientTLSConfig(
  124. c.Config.SSLKey,
  125. c.Config.SSLCer,
  126. []byte(c.Config.SSLCA),
  127. c.Config.VerifyCert,
  128. c.Config.Host,
  129. )
  130. c.secTCPConn = tls.Client(c.tcpConn, tlsConf)
  131. c.setConnection(c.secTCPConn)
  132. }
  133. err = c.writeHandshakeResponse()
  134. if err != nil {
  135. return nil, err
  136. }
  137. err = c.listen()
  138. return c, err
  139. }
  140. func (c *Conn) listen() error {
  141. fmt.Println("LISTEN")
  142. ph, err := c.getPacketHeader()
  143. if err != nil {
  144. return err
  145. }
  146. switch ph.Status {
  147. case 0x01:
  148. fmt.Println("IN: AuthMoreDate PACKET")
  149. _, err := c.decodeAuthMoreDataResponsePacket(ph)
  150. if err != nil {
  151. return err
  152. }
  153. case 0x00:
  154. fmt.Println("IN: OK PACKET")
  155. case 0xFE:
  156. fmt.Println("IN: EOF PACKET")
  157. case 0xFF:
  158. fmt.Println("IN: ERROR PACKET")
  159. ep, err := c.decodeErrorPacket(ph)
  160. if err != nil {
  161. return err
  162. }
  163. err = errors.New(fmt.Sprintf("Error %d: %s", ep.ErrorCode, ep.ErrorMessage))
  164. return err
  165. default:
  166. fmt.Printf("ph = %+v\n", ph)
  167. fmt.Println("IN: UNKNOWN PACKET")
  168. }
  169. err = c.scanner.Err()
  170. if err != nil {
  171. return err
  172. }
  173. err = c.listen() // Listen forever until we get an error.
  174. if err != nil {
  175. return err
  176. }
  177. return nil
  178. }
  179. type PacketHeader struct {
  180. Length uint64
  181. SequenceID uint64
  182. Status uint64
  183. }
  184. func (c *Conn) getPacketHeader() (PacketHeader, error) {
  185. ph := PacketHeader{}
  186. ph.Length = c.getInt(TypeFixedInt, 3)
  187. ph.SequenceID = c.getInt(TypeFixedInt, 1)
  188. ph.Status = c.getInt(TypeFixedInt, 1)
  189. err := c.scanner.Err()
  190. if err != nil {
  191. return ph, err
  192. }
  193. return ph, nil
  194. }
  195. func init() {
  196. sql.Register("mysql-binlog", &Driver{})
  197. }
  198. func (c *Conn) readBytes(l uint64) *bytes.Buffer {
  199. b := make([]byte, 0)
  200. for i := uint64(0); i < l; i++ {
  201. c.scanner.Scan()
  202. b = append(b, c.scanner.Bytes()...)
  203. }
  204. return bytes.NewBuffer(b)
  205. }
  206. func (c *Conn) getBytesUntilEOF() *bytes.Buffer {
  207. l := uint64(1)
  208. s := c.readBytes(l)
  209. b := s.Bytes()
  210. for true {
  211. if uint64(s.Len()) != l || s.Bytes()[0] == NullByte {
  212. break
  213. }
  214. s = c.readBytes(uint64(l))
  215. err := c.scanner.Err()
  216. if err == io.EOF {
  217. return bytes.NewBuffer(b)
  218. } else if err != nil {
  219. panic(err)
  220. }
  221. b = append(b, s.Bytes()...)
  222. }
  223. return bytes.NewBuffer(b)
  224. }
  225. func (c *Conn) getBytesUntilNull() *bytes.Buffer {
  226. l := uint64(1)
  227. s := c.readBytes(l)
  228. b := s.Bytes()
  229. for true {
  230. if uint64(s.Len()) != l || s.Bytes()[0] == NullByte {
  231. break
  232. }
  233. s = c.readBytes(uint64(l))
  234. b = append(b, s.Bytes()...)
  235. }
  236. return bytes.NewBuffer(b)
  237. }
  238. func (c *Conn) discardBytes(l int) {
  239. for i := 0; i < l; i++ {
  240. c.scanner.Scan()
  241. }
  242. }
  243. func (c *Conn) getInt(t int, l uint64) uint64 {
  244. var v uint64
  245. switch t {
  246. case TypeFixedInt:
  247. v = c.decFixedInt(l)
  248. default:
  249. v = 0
  250. }
  251. return v
  252. }
  253. func (c *Conn) getString(t int, l uint64) string {
  254. var v string
  255. switch t {
  256. case TypeFixedString:
  257. v = c.decFixedString(l)
  258. case TypeNullTerminatedString:
  259. v = c.decNullTerminatedString()
  260. case TypeRestOfPacketString:
  261. v = c.decRestOfPacketString()
  262. default:
  263. v = ""
  264. }
  265. return v
  266. }
  267. func (c *Conn) decRestOfPacketString() string {
  268. b := c.getBytesUntilEOF()
  269. return string(b.Bytes())
  270. }
  271. func (c *Conn) decNullTerminatedString() string {
  272. b := c.getBytesUntilNull()
  273. return strings.TrimRight(b.String(), string(NullByte))
  274. }
  275. func (c *Conn) decFixedString(l uint64) string {
  276. b := c.readBytes(l)
  277. return b.String()
  278. }
  279. func (c *Conn) decFixedInt(l uint64) uint64 {
  280. var i uint64
  281. b := c.readBytes(l)
  282. if l <= 2 {
  283. var x uint16
  284. pb := c.padBytes(2, b.Bytes())
  285. br := bytes.NewReader(pb)
  286. _ = binary.Read(br, binary.LittleEndian, &x)
  287. i = uint64(x)
  288. } else if l <= 4 {
  289. var x uint32
  290. pb := c.padBytes(4, b.Bytes())
  291. br := bytes.NewReader(pb)
  292. _ = binary.Read(br, binary.LittleEndian, &x)
  293. i = uint64(x)
  294. } else if l <= 8 {
  295. var x uint64
  296. pb := c.padBytes(8, b.Bytes())
  297. br := bytes.NewReader(pb)
  298. _ = binary.Read(br, binary.LittleEndian, &x)
  299. i = x
  300. }
  301. return i
  302. }
  303. func (c *Conn) padBytes(l int, b []byte) []byte {
  304. bl := len(b)
  305. pl := l - bl
  306. for i := 0; i < pl; i++ {
  307. b = append(b, NullByte)
  308. }
  309. return b
  310. }
  311. func (c *Conn) encFixedLenInt(v uint64, l uint64) []byte {
  312. b := make([]byte, 8)
  313. binary.LittleEndian.PutUint64(b, v)
  314. return b[:l]
  315. }
  316. func (c *Conn) encLenEncInt(v uint64) []byte {
  317. prefix := make([]byte, 1)
  318. var b []byte
  319. switch {
  320. case v < MaxUint08:
  321. b = make([]byte, 2)
  322. binary.LittleEndian.PutUint16(b, uint16(v))
  323. b = b[:1]
  324. case v >= MaxUint08 && v < MaxUint16:
  325. prefix[0] = 0xFC
  326. b = make([]byte, 3)
  327. binary.LittleEndian.PutUint16(b, uint16(v))
  328. b = b[:2]
  329. case v >= MaxUint16 && v < MaxUint24:
  330. prefix[0] = 0xFD
  331. b = make([]byte, 4)
  332. binary.LittleEndian.PutUint32(b, uint32(v))
  333. b = b[:3]
  334. case v >= MaxUint24 && v < MaxUint64:
  335. prefix[0] = 0xFE
  336. b = make([]byte, 9)
  337. binary.LittleEndian.PutUint64(b, uint64(v))
  338. }
  339. if len(b) > 1 {
  340. b = append(prefix, b...)
  341. }
  342. return b
  343. }
  344. func (c *Conn) bitmaskToStruct(b []byte, s interface{}) interface{} {
  345. l := len(b)
  346. t := reflect.TypeOf(s)
  347. v := reflect.New(t.Elem()).Elem()
  348. for i := uint(0); i < uint(v.NumField()); i++ {
  349. f := v.Field(int(i))
  350. var v bool
  351. switch {
  352. case l > 4:
  353. x := binary.LittleEndian.Uint64(b)
  354. flag := uint64(1 << i)
  355. v = x&flag > 0
  356. case l > 2:
  357. x := binary.LittleEndian.Uint32(b)
  358. flag := uint32(1 << i)
  359. v = x&flag > 0
  360. case l > 1:
  361. x := binary.LittleEndian.Uint16(b)
  362. flag := uint16(1 << i)
  363. v = x&flag > 0
  364. default:
  365. x := uint(b[0])
  366. flag := uint(1 << i)
  367. v = x&flag > 0
  368. }
  369. f.SetBool(v)
  370. }
  371. return v.Interface()
  372. }
  373. func (c *Conn) structToBitmask(s interface{}) []byte {
  374. t := reflect.TypeOf(s).Elem()
  375. sV := reflect.ValueOf(s).Elem()
  376. fC := uint(t.NumField())
  377. m := uint64(0)
  378. for i := uint(0); i < fC; i++ {
  379. f := sV.Field(int(i))
  380. v := f.Bool()
  381. if v {
  382. m |= 1 << i
  383. }
  384. }
  385. l := uint64(math.Ceil(float64(fC) / 8.0))
  386. b := make([]byte, 8)
  387. binary.LittleEndian.PutUint64(b, m)
  388. switch {
  389. case l > 4: // 64 bits
  390. b = b[:8]
  391. case l > 2: // 32 bits
  392. b = b[:4]
  393. case l > 1: // 16 bits
  394. b = b[:2]
  395. default: // 8 bits
  396. b = b[:1]
  397. }
  398. return b
  399. }
  400. func (c *Conn) putString(t int, v string) uint64 {
  401. b := make([]byte, 0)
  402. switch t {
  403. case TypeFixedString:
  404. b = c.encFixedString(v)
  405. case TypeNullTerminatedString:
  406. b = c.encNullTerminatedString(v)
  407. case TypeRestOfPacketString:
  408. b = c.encRestOfPacketString(v)
  409. }
  410. l, err := c.writeBuf.Write(b)
  411. if err != nil {
  412. c.err = err
  413. }
  414. return uint64(l)
  415. }
  416. func (c *Conn) encNullTerminatedString(v string) []byte {
  417. return append([]byte(v), NullByte)
  418. }
  419. func (c *Conn) encFixedString(v string) []byte {
  420. return []byte(v)
  421. }
  422. func (c *Conn) encRestOfPacketString(v string) []byte {
  423. s := c.encFixedString(v)
  424. return s
  425. }
  426. func (c *Conn) putInt(t int, v uint64, l uint64) uint64 {
  427. c.setupWriteBuffer()
  428. b := make([]byte, 0)
  429. switch t {
  430. case TypeFixedInt:
  431. b = c.encFixedLenInt(v, l)
  432. case TypeLenEncInt:
  433. b = c.encLenEncInt(v)
  434. }
  435. n, err := c.writeBuf.Write(b)
  436. if err != nil {
  437. c.err = err
  438. }
  439. return uint64(n)
  440. }
  441. func (c *Conn) putNullBytes(n uint64) uint64 {
  442. c.setupWriteBuffer()
  443. b := make([]byte, n)
  444. l, err := c.writeBuf.Write(b)
  445. if err != nil {
  446. c.err = err
  447. }
  448. return uint64(l)
  449. }
  450. func (c *Conn) putBytes(v []byte) uint64 {
  451. c.setupWriteBuffer()
  452. l, err := c.writeBuf.Write(v)
  453. if err != nil {
  454. c.err = err
  455. }
  456. return uint64(l)
  457. }
  458. func (c *Conn) Flush() error {
  459. if c.err != nil {
  460. return c.err
  461. }
  462. c.writeBuf = c.addHeader()
  463. _, _ = c.buffer.Write(c.writeBuf.Bytes())
  464. // log all packets
  465. fmt.Printf(
  466. "\nOUT:\n%08b\n%x\n%s\n\n",
  467. c.writeBuf.Bytes(),
  468. c.writeBuf.Bytes(),
  469. c.writeBuf.Bytes(),
  470. )
  471. if c.buffer.Flush() != nil {
  472. return c.buffer.Flush()
  473. }
  474. c.writeBuf = nil
  475. return nil
  476. }
  477. func (c *Conn) addHeader() *bytes.Buffer {
  478. pl := uint64(c.writeBuf.Len())
  479. sId := uint64(c.sequenceId)
  480. c.sequenceId++
  481. plB := c.encFixedLenInt(pl, 3)
  482. sIdB := c.encFixedLenInt(sId, 1)
  483. return bytes.NewBuffer(append(append(plB, sIdB...), c.writeBuf.Bytes()...))
  484. }
  485. func (c *Conn) setupWriteBuffer() {
  486. if c.writeBuf == nil {
  487. c.writeBuf = bytes.NewBuffer(nil)
  488. }
  489. }
  490. type ErrorPacket struct {
  491. PacketHeader
  492. ErrorCode uint64
  493. ErrorMessage string
  494. SQLStateMarker string
  495. SQLState string
  496. }
  497. func (c *Conn) decodeErrorPacket(ph PacketHeader) (*ErrorPacket, error) {
  498. ep := ErrorPacket{}
  499. ep.PacketHeader = ph
  500. ep.ErrorCode = c.getInt(TypeFixedInt, 2)
  501. ep.SQLStateMarker = c.getString(TypeFixedString, 1)
  502. ep.SQLState = c.getString(TypeFixedString, 5)
  503. ep.ErrorMessage = c.getString(TypeRestOfPacketString, 0)
  504. err := c.scanner.Err()
  505. if err != nil {
  506. return nil, err
  507. }
  508. return &ep, nil
  509. }
  510. func (c *Conn) setConnection(nc net.Conn) {
  511. c.curConn = nc
  512. c.buffer = bufio.NewReadWriter(
  513. bufio.NewReader(c.curConn),
  514. bufio.NewWriter(c.curConn),
  515. )
  516. c.scanner = bufio.NewScanner(c.buffer.Reader)
  517. c.scanner.Split(splitByBytesFunc)
  518. }