connection.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636
  1. package binlog
  2. import (
  3. "bufio"
  4. "bytes"
  5. "crypto/tls"
  6. "database/sql"
  7. "database/sql/driver"
  8. "encoding/binary"
  9. "encoding/json"
  10. "errors"
  11. "fmt"
  12. "io/ioutil"
  13. "math"
  14. "net"
  15. "reflect"
  16. "strings"
  17. "time"
  18. )
  19. // Misc. Constants
  20. const NullByte byte = 0
  21. var EOF = bytes.NewBuffer([]byte{NullByte})
  22. const MaxPacketSize = MaxUint16
  23. // MySQL Packet Data Types
  24. const TypeNullTerminatedString = int(0)
  25. const TypeFixedString = int(1)
  26. const TypeFixedInt = int(2)
  27. const TypeLenEncInt = int(3)
  28. const TypeRestOfPacketString = int(4)
  29. const TypeLenEncString = int(5)
  30. // Integer Maximums
  31. const MaxUint08 = 1<<8 - 1
  32. const MaxUint16 = 1<<16 - 1
  33. const MaxUint24 = 1<<24 - 1
  34. const MaxUint64 = 1<<64 - 1
  35. // Packet Statuses
  36. const StatusOK = 0x00
  37. const StatusEOF = 0xFE
  38. const StatusErr = 0xFF
  39. const StatusAuth = 0x01
  40. type Config struct {
  41. Host string `json:"host"`
  42. Port int `json:"port"`
  43. User string `json:"user"`
  44. Pass string `json:"password"`
  45. Database string `json:"database"`
  46. SSL bool `json:"ssl"`
  47. SSLCA string `json:"ssl-ca"`
  48. SSLCer string `json:"ssl-cer"`
  49. SSLKey string `json:"ssl-key"`
  50. VerifyCert bool `json:"verify-cert"`
  51. ServerId uint64 `json:"server-id"`
  52. BinLogFile string `json:"binlog-file"`
  53. Timeout time.Duration
  54. }
  55. func newBinlogConfig(dsn string) (*Config, error) {
  56. var err error
  57. b, err := ioutil.ReadFile(dsn)
  58. if err != nil {
  59. return nil, err
  60. }
  61. config := Config{}
  62. err = json.Unmarshal(b, &config)
  63. return &config, err
  64. }
  65. type Conn struct {
  66. Config *Config
  67. curConn net.Conn
  68. tcpConn *net.TCPConn
  69. secTCPConn *tls.Conn
  70. Handshake *Handshake
  71. HandshakeResponse *HandshakeResponse
  72. buffer *bufio.ReadWriter
  73. scanner *bufio.Scanner
  74. err error
  75. sequenceId uint64
  76. writeBuf *bytes.Buffer
  77. StausFlags *StatusFlags
  78. Listener *net.Listener
  79. }
  80. func newBinlogConn(config *Config) Conn {
  81. return Conn{
  82. Config: config,
  83. sequenceId: 1,
  84. }
  85. }
  86. func (c Conn) Prepare(query string) (driver.Stmt, error) {
  87. return nil, nil
  88. }
  89. func (c Conn) Close() error {
  90. return nil
  91. }
  92. func (c Conn) Begin() (driver.Tx, error) {
  93. return nil, nil
  94. }
  95. type Driver struct{}
  96. func (d Driver) Open(dsn string) (driver.Conn, error) {
  97. config, err := newBinlogConfig(dsn)
  98. if nil != err {
  99. return nil, err
  100. }
  101. c := newBinlogConn(config)
  102. var t interface{}
  103. dialer := net.Dialer{Timeout: c.Config.Timeout}
  104. addr := fmt.Sprintf("%s:%d", c.Config.Host, c.Config.Port)
  105. t, err = dialer.Dial("tcp", addr)
  106. if err != nil {
  107. netErr, ok := err.(net.Error)
  108. if ok && !netErr.Temporary() {
  109. fmt.Printf("Error: %s", netErr.Error())
  110. return nil, err
  111. }
  112. } else {
  113. c.tcpConn = t.(*net.TCPConn)
  114. c.setConnection(t.(net.Conn))
  115. }
  116. err = c.decodeHandshakePacket()
  117. if err != nil {
  118. return nil, err
  119. }
  120. c.HandshakeResponse = c.NewHandshakeResponse()
  121. // If we are on SSL send SSL_Request packet now
  122. if c.Config.SSL {
  123. err = c.writeSSLRequestPacket()
  124. if err != nil {
  125. return nil, err
  126. }
  127. tlsConf := NewClientTLSConfig(
  128. c.Config.SSLKey,
  129. c.Config.SSLCer,
  130. []byte(c.Config.SSLCA),
  131. c.Config.VerifyCert,
  132. c.Config.Host,
  133. )
  134. c.secTCPConn = tls.Client(c.tcpConn, tlsConf)
  135. c.setConnection(c.secTCPConn)
  136. }
  137. err = c.writeHandshakeResponse()
  138. if err != nil {
  139. return nil, err
  140. }
  141. return c, err
  142. }
  143. type PacketHeader struct {
  144. Length uint64
  145. SequenceID uint64
  146. Status uint64
  147. }
  148. func (c *Conn) getPacketHeader() (PacketHeader, error) {
  149. ph := PacketHeader{}
  150. ph.Length = c.getInt(TypeFixedInt, 3)
  151. ph.SequenceID = c.getInt(TypeFixedInt, 1)
  152. ph.Status = c.getInt(TypeFixedInt, 1)
  153. err := c.scanner.Err()
  154. if err != nil {
  155. return ph, err
  156. }
  157. return ph, nil
  158. }
  159. func init() {
  160. sql.Register("mysql-binlog", &Driver{})
  161. }
  162. func (c *Conn) readBytes(l uint64) *bytes.Buffer {
  163. b := make([]byte, 0)
  164. for i := uint64(0); i < l; i++ {
  165. didScan := c.scanner.Scan()
  166. if !didScan {
  167. err := c.scanner.Err()
  168. if err == nil { // scanner reached EOF
  169. return EOF
  170. } else {
  171. panic(err) // @TODO Handle this gracefully.
  172. }
  173. return nil
  174. }
  175. b = append(b, c.scanner.Bytes()...)
  176. }
  177. return bytes.NewBuffer(b)
  178. }
  179. func (c *Conn) getBytesUntilEOF() *bytes.Buffer {
  180. l := uint64(1)
  181. s := c.readBytes(l)
  182. b := s.Bytes()
  183. for true {
  184. if uint64(s.Len()) != l || s.Bytes()[0] == NullByte {
  185. break
  186. }
  187. s := c.readBytes(uint64(l))
  188. if s == EOF || s == nil {
  189. return bytes.NewBuffer(b)
  190. }
  191. b = append(b, s.Bytes()...)
  192. }
  193. return bytes.NewBuffer(b)
  194. }
  195. func (c *Conn) getBytesUntilNull() *bytes.Buffer {
  196. l := uint64(1)
  197. s := c.readBytes(l)
  198. b := s.Bytes()
  199. for true {
  200. if uint64(s.Len()) != l || s.Bytes()[0] == NullByte {
  201. break
  202. }
  203. s = c.readBytes(uint64(l))
  204. b = append(b, s.Bytes()...)
  205. }
  206. return bytes.NewBuffer(b)
  207. }
  208. func (c *Conn) discardBytes(l uint64) {
  209. c.readBytes(l)
  210. }
  211. func (c *Conn) getInt(t int, l uint64) uint64 {
  212. var v uint64
  213. switch t {
  214. case TypeFixedInt:
  215. v = c.decFixedInt(l)
  216. case TypeLenEncInt:
  217. v = c.decLenEncInt()
  218. default:
  219. v = 0
  220. }
  221. return v
  222. }
  223. func (c *Conn) getString(t int, l uint64) string {
  224. var v string
  225. switch t {
  226. case TypeFixedString:
  227. v = c.decFixedString(l)
  228. case TypeLenEncString:
  229. v = string(c.decLenEncInt())
  230. case TypeNullTerminatedString:
  231. v = c.decNullTerminatedString()
  232. case TypeRestOfPacketString:
  233. v = c.decRestOfPacketString()
  234. default:
  235. v = ""
  236. }
  237. return v
  238. }
  239. func (c *Conn) decRestOfPacketString() string {
  240. b := c.getBytesUntilEOF()
  241. return string(b.Bytes())
  242. }
  243. func (c *Conn) decNullTerminatedString() string {
  244. b := c.getBytesUntilNull()
  245. return strings.TrimRight(b.String(), string(NullByte))
  246. }
  247. func (c *Conn) decFixedString(l uint64) string {
  248. b := c.readBytes(l)
  249. return b.String()
  250. }
  251. func (c *Conn) decLenEncInt() uint64 {
  252. var l uint16
  253. b := c.readBytes(1)
  254. br := bytes.NewReader(b.Bytes())
  255. _ = binary.Read(br, binary.LittleEndian, &l)
  256. if l > 0 {
  257. return c.decFixedInt(uint64(l))
  258. } else {
  259. return 0
  260. }
  261. }
  262. func (c *Conn) decFixedInt(l uint64) uint64 {
  263. var i uint64
  264. b := c.readBytes(l)
  265. if l <= 2 {
  266. var x uint16
  267. pb := c.padBytes(2, b.Bytes())
  268. br := bytes.NewReader(pb)
  269. _ = binary.Read(br, binary.LittleEndian, &x)
  270. i = uint64(x)
  271. } else if l <= 4 {
  272. var x uint32
  273. pb := c.padBytes(4, b.Bytes())
  274. br := bytes.NewReader(pb)
  275. _ = binary.Read(br, binary.LittleEndian, &x)
  276. i = uint64(x)
  277. } else if l <= 8 {
  278. var x uint64
  279. pb := c.padBytes(8, b.Bytes())
  280. br := bytes.NewReader(pb)
  281. _ = binary.Read(br, binary.LittleEndian, &x)
  282. i = x
  283. }
  284. return i
  285. }
  286. func (c *Conn) padBytes(l int, b []byte) []byte {
  287. bl := len(b)
  288. pl := l - bl
  289. for i := 0; i < pl; i++ {
  290. b = append(b, NullByte)
  291. }
  292. return b
  293. }
  294. func (c *Conn) encFixedLenInt(v uint64, l uint64) []byte {
  295. b := make([]byte, 8)
  296. binary.LittleEndian.PutUint64(b, v)
  297. return b[:l]
  298. }
  299. func (c *Conn) encLenEncInt(v uint64) []byte {
  300. prefix := make([]byte, 1)
  301. var b []byte
  302. switch {
  303. case v < MaxUint08:
  304. b = make([]byte, 2)
  305. binary.LittleEndian.PutUint16(b, uint16(v))
  306. b = b[:1]
  307. case v >= MaxUint08 && v < MaxUint16:
  308. prefix[0] = 0xFC
  309. b = make([]byte, 3)
  310. binary.LittleEndian.PutUint16(b, uint16(v))
  311. b = b[:2]
  312. case v >= MaxUint16 && v < MaxUint24:
  313. prefix[0] = 0xFD
  314. b = make([]byte, 4)
  315. binary.LittleEndian.PutUint32(b, uint32(v))
  316. b = b[:3]
  317. case v >= MaxUint24 && v < MaxUint64:
  318. prefix[0] = 0xFE
  319. b = make([]byte, 9)
  320. binary.LittleEndian.PutUint64(b, uint64(v))
  321. }
  322. if len(b) > 1 {
  323. b = append(prefix, b...)
  324. }
  325. return b
  326. }
  327. func (c *Conn) bitmaskToStruct(b []byte, s interface{}) interface{} {
  328. l := len(b)
  329. t := reflect.TypeOf(s)
  330. v := reflect.New(t.Elem()).Elem()
  331. for i := uint(0); i < uint(v.NumField()); i++ {
  332. f := v.Field(int(i))
  333. var v bool
  334. switch {
  335. case l > 4:
  336. x := binary.LittleEndian.Uint64(b)
  337. flag := uint64(1 << i)
  338. v = x&flag > 0
  339. case l > 2:
  340. x := binary.LittleEndian.Uint32(b)
  341. flag := uint32(1 << i)
  342. v = x&flag > 0
  343. case l > 1:
  344. x := binary.LittleEndian.Uint16(b)
  345. flag := uint16(1 << i)
  346. v = x&flag > 0
  347. default:
  348. x := uint(b[0])
  349. flag := uint(1 << i)
  350. v = x&flag > 0
  351. }
  352. f.SetBool(v)
  353. }
  354. return v.Interface()
  355. }
  356. func (c *Conn) structToBitmask(s interface{}) []byte {
  357. t := reflect.TypeOf(s).Elem()
  358. sV := reflect.ValueOf(s).Elem()
  359. fC := uint(t.NumField())
  360. m := uint64(0)
  361. for i := uint(0); i < fC; i++ {
  362. f := sV.Field(int(i))
  363. v := f.Bool()
  364. if v {
  365. m |= 1 << i
  366. }
  367. }
  368. l := uint64(math.Ceil(float64(fC) / 8.0))
  369. b := make([]byte, 8)
  370. binary.LittleEndian.PutUint64(b, m)
  371. switch {
  372. case l > 4: // 64 bits
  373. b = b[:8]
  374. case l > 2: // 32 bits
  375. b = b[:4]
  376. case l > 1: // 16 bits
  377. b = b[:2]
  378. default: // 8 bits
  379. b = b[:1]
  380. }
  381. return b
  382. }
  383. func (c *Conn) putString(t int, v string) uint64 {
  384. b := make([]byte, 0)
  385. switch t {
  386. case TypeFixedString:
  387. b = c.encFixedString(v)
  388. case TypeNullTerminatedString:
  389. b = c.encNullTerminatedString(v)
  390. case TypeRestOfPacketString:
  391. b = c.encRestOfPacketString(v)
  392. }
  393. l, err := c.writeBuf.Write(b)
  394. if err != nil {
  395. c.err = err
  396. }
  397. return uint64(l)
  398. }
  399. func (c *Conn) encNullTerminatedString(v string) []byte {
  400. return append([]byte(v), NullByte)
  401. }
  402. func (c *Conn) encFixedString(v string) []byte {
  403. return []byte(v)
  404. }
  405. func (c *Conn) encRestOfPacketString(v string) []byte {
  406. s := c.encFixedString(v)
  407. return s
  408. }
  409. func (c *Conn) putInt(t int, v uint64, l uint64) uint64 {
  410. c.setupWriteBuffer()
  411. b := make([]byte, 0)
  412. switch t {
  413. case TypeFixedInt:
  414. b = c.encFixedLenInt(v, l)
  415. case TypeLenEncInt:
  416. b = c.encLenEncInt(v)
  417. }
  418. n, err := c.writeBuf.Write(b)
  419. if err != nil {
  420. c.err = err
  421. }
  422. return uint64(n)
  423. }
  424. func (c *Conn) putNullBytes(n uint64) uint64 {
  425. c.setupWriteBuffer()
  426. b := make([]byte, n)
  427. l, err := c.writeBuf.Write(b)
  428. if err != nil {
  429. c.err = err
  430. }
  431. return uint64(l)
  432. }
  433. func (c *Conn) putBytes(v []byte) uint64 {
  434. c.setupWriteBuffer()
  435. l, err := c.writeBuf.Write(v)
  436. if err != nil {
  437. c.err = err
  438. }
  439. return uint64(l)
  440. }
  441. func (c *Conn) Flush() error {
  442. if c.err != nil {
  443. return c.err
  444. }
  445. c.writeBuf = c.addHeader()
  446. _, _ = c.buffer.Write(c.writeBuf.Bytes())
  447. if c.buffer.Flush() != nil {
  448. return c.buffer.Flush()
  449. }
  450. c.writeBuf = nil
  451. return nil
  452. }
  453. func (c *Conn) addHeader() *bytes.Buffer {
  454. pl := uint64(c.writeBuf.Len())
  455. sId := uint64(c.sequenceId)
  456. c.sequenceId++
  457. plB := c.encFixedLenInt(pl, 3)
  458. sIdB := c.encFixedLenInt(sId, 1)
  459. return bytes.NewBuffer(append(append(plB, sIdB...), c.writeBuf.Bytes()...))
  460. }
  461. func (c *Conn) setupWriteBuffer() {
  462. if c.writeBuf == nil {
  463. c.writeBuf = bytes.NewBuffer(nil)
  464. }
  465. }
  466. type StatusFlags struct {
  467. }
  468. type OKPacket struct {
  469. PacketHeader
  470. Header uint64
  471. AffectedRows uint64
  472. LastInsertID uint64
  473. StatusFlags uint64
  474. Warnings uint64
  475. Info string
  476. SessionStateInfo string
  477. }
  478. func (c *Conn) decodeOKPacket(ph PacketHeader) (*OKPacket, error) {
  479. op := OKPacket{}
  480. op.PacketHeader = ph
  481. op.Header = ph.Status
  482. op.AffectedRows = c.getInt(TypeLenEncInt, 0)
  483. op.LastInsertID = c.getInt(TypeLenEncInt, 0)
  484. if c.HandshakeResponse.ClientFlag.Protocol41 {
  485. op.StatusFlags = c.getInt(TypeFixedInt, 2)
  486. op.Warnings = c.getInt(TypeFixedInt, 1)
  487. } else if c.HandshakeResponse.ClientFlag.Transactions {
  488. op.StatusFlags = c.getInt(TypeFixedInt, 2)
  489. }
  490. if c.HandshakeResponse.ClientFlag.SessionTrack {
  491. op.Info = c.getString(TypeRestOfPacketString, 0)
  492. } else {
  493. op.Info = c.getString(TypeRestOfPacketString, 0)
  494. }
  495. return &op, nil
  496. }
  497. type ErrorPacket struct {
  498. PacketHeader
  499. ErrorCode uint64
  500. ErrorMessage string
  501. SQLStateMarker string
  502. SQLState string
  503. }
  504. func (c *Conn) decodeErrorPacket(ph PacketHeader) (*ErrorPacket, error) {
  505. ep := ErrorPacket{}
  506. ep.PacketHeader = ph
  507. ep.ErrorCode = c.getInt(TypeFixedInt, 2)
  508. ep.SQLStateMarker = c.getString(TypeFixedString, 1)
  509. ep.SQLState = c.getString(TypeFixedString, 5)
  510. ep.ErrorMessage = c.getString(TypeRestOfPacketString, 0)
  511. err := c.scanner.Err()
  512. if err != nil {
  513. return nil, err
  514. }
  515. return &ep, nil
  516. }
  517. func (c *Conn) setConnection(nc net.Conn) {
  518. c.curConn = nc
  519. c.buffer = bufio.NewReadWriter(
  520. bufio.NewReader(c.curConn),
  521. bufio.NewWriter(c.curConn),
  522. )
  523. c.scanner = bufio.NewScanner(c.buffer.Reader)
  524. c.scanner.Split(bufio.ScanBytes)
  525. }