connection.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635
  1. package binlog
  2. import (
  3. "bufio"
  4. "bytes"
  5. "crypto/tls"
  6. "database/sql"
  7. "database/sql/driver"
  8. "encoding/binary"
  9. "encoding/json"
  10. "fmt"
  11. "io/ioutil"
  12. "math"
  13. "net"
  14. "reflect"
  15. "strings"
  16. "time"
  17. )
  18. // Misc. Constants
  19. const NullByte byte = 0
  20. var EOF = bytes.NewBuffer([]byte{NullByte})
  21. const MaxPacketSize = MaxUint16
  22. // MySQL Packet Data Types
  23. const TypeNullTerminatedString = int(0)
  24. const TypeFixedString = int(1)
  25. const TypeFixedInt = int(2)
  26. const TypeLenEncInt = int(3)
  27. const TypeRestOfPacketString = int(4)
  28. const TypeLenEncString = int(5)
  29. // Integer Maximums
  30. const MaxUint08 = 1<<8 - 1
  31. const MaxUint16 = 1<<16 - 1
  32. const MaxUint24 = 1<<24 - 1
  33. const MaxUint64 = 1<<64 - 1
  34. // Packet Statuses
  35. const StatusOK = 0x00
  36. const StatusEOF = 0xFE
  37. const StatusErr = 0xFF
  38. const StatusAuth = 0x01
  39. type Config struct {
  40. Host string `json:"host"`
  41. Port int `json:"port"`
  42. User string `json:"user"`
  43. Pass string `json:"password"`
  44. Database string `json:"database"`
  45. SSL bool `json:"ssl"`
  46. SSLCA string `json:"ssl-ca"`
  47. SSLCer string `json:"ssl-cer"`
  48. SSLKey string `json:"ssl-key"`
  49. VerifyCert bool `json:"verify-cert"`
  50. ServerId uint64 `json:"server-id"`
  51. BinLogFile string `json:"binlog-file"`
  52. Timeout time.Duration
  53. }
  54. func newBinlogConfig(dsn string) (*Config, error) {
  55. var err error
  56. b, err := ioutil.ReadFile(dsn)
  57. if err != nil {
  58. return nil, err
  59. }
  60. config := Config{}
  61. err = json.Unmarshal(b, &config)
  62. return &config, err
  63. }
  64. type Conn struct {
  65. Config *Config
  66. curConn net.Conn
  67. tcpConn *net.TCPConn
  68. secTCPConn *tls.Conn
  69. Handshake *Handshake
  70. HandshakeResponse *HandshakeResponse
  71. buffer *bufio.ReadWriter
  72. scanner *bufio.Scanner
  73. err error
  74. sequenceId uint64
  75. writeBuf *bytes.Buffer
  76. StausFlags *StatusFlags
  77. Listener *net.Listener
  78. }
  79. func newBinlogConn(config *Config) Conn {
  80. return Conn{
  81. Config: config,
  82. sequenceId: 1,
  83. }
  84. }
  85. func (c Conn) Prepare(query string) (driver.Stmt, error) {
  86. return nil, nil
  87. }
  88. func (c Conn) Close() error {
  89. return nil
  90. }
  91. func (c Conn) Begin() (driver.Tx, error) {
  92. return nil, nil
  93. }
  94. type Driver struct{}
  95. func (d Driver) Open(dsn string) (driver.Conn, error) {
  96. config, err := newBinlogConfig(dsn)
  97. if nil != err {
  98. return nil, err
  99. }
  100. c := newBinlogConn(config)
  101. var t interface{}
  102. dialer := net.Dialer{Timeout: c.Config.Timeout}
  103. addr := fmt.Sprintf("%s:%d", c.Config.Host, c.Config.Port)
  104. t, err = dialer.Dial("tcp", addr)
  105. if err != nil {
  106. netErr, ok := err.(net.Error)
  107. if ok && !netErr.Temporary() {
  108. fmt.Printf("Error: %s", netErr.Error())
  109. return nil, err
  110. }
  111. } else {
  112. c.tcpConn = t.(*net.TCPConn)
  113. c.setConnection(t.(net.Conn))
  114. }
  115. err = c.decodeHandshakePacket()
  116. if err != nil {
  117. return nil, err
  118. }
  119. c.HandshakeResponse = c.NewHandshakeResponse()
  120. // If we are on SSL send SSL_Request packet now
  121. if c.Config.SSL {
  122. err = c.writeSSLRequestPacket()
  123. if err != nil {
  124. return nil, err
  125. }
  126. tlsConf := NewClientTLSConfig(
  127. c.Config.SSLKey,
  128. c.Config.SSLCer,
  129. []byte(c.Config.SSLCA),
  130. c.Config.VerifyCert,
  131. c.Config.Host,
  132. )
  133. c.secTCPConn = tls.Client(c.tcpConn, tlsConf)
  134. c.setConnection(c.secTCPConn)
  135. }
  136. err = c.writeHandshakeResponse()
  137. if err != nil {
  138. return nil, err
  139. }
  140. return c, err
  141. }
  142. type PacketHeader struct {
  143. Length uint64
  144. SequenceID uint64
  145. Status uint64
  146. }
  147. func (c *Conn) getPacketHeader() (PacketHeader, error) {
  148. ph := PacketHeader{}
  149. ph.Length = c.getInt(TypeFixedInt, 3)
  150. ph.SequenceID = c.getInt(TypeFixedInt, 1)
  151. ph.Status = c.getInt(TypeFixedInt, 1)
  152. err := c.scanner.Err()
  153. if err != nil {
  154. return ph, err
  155. }
  156. return ph, nil
  157. }
  158. func init() {
  159. sql.Register("mysql-binlog", &Driver{})
  160. }
  161. func (c *Conn) readBytes(l uint64) *bytes.Buffer {
  162. b := make([]byte, 0)
  163. for i := uint64(0); i < l; i++ {
  164. didScan := c.scanner.Scan()
  165. if !didScan {
  166. err := c.scanner.Err()
  167. if err == nil { // scanner reached EOF
  168. return EOF
  169. } else {
  170. panic(err) // @TODO Handle this gracefully.
  171. }
  172. return nil
  173. }
  174. b = append(b, c.scanner.Bytes()...)
  175. }
  176. return bytes.NewBuffer(b)
  177. }
  178. func (c *Conn) getBytesUntilEOF() *bytes.Buffer {
  179. l := uint64(1)
  180. s := c.readBytes(l)
  181. b := s.Bytes()
  182. for true {
  183. if uint64(s.Len()) != l || s.Bytes()[0] == NullByte {
  184. break
  185. }
  186. s := c.readBytes(uint64(l))
  187. if s == EOF || s == nil {
  188. return bytes.NewBuffer(b)
  189. }
  190. b = append(b, s.Bytes()...)
  191. }
  192. return bytes.NewBuffer(b)
  193. }
  194. func (c *Conn) getBytesUntilNull() *bytes.Buffer {
  195. l := uint64(1)
  196. s := c.readBytes(l)
  197. b := s.Bytes()
  198. for true {
  199. if uint64(s.Len()) != l || s.Bytes()[0] == NullByte {
  200. break
  201. }
  202. s = c.readBytes(uint64(l))
  203. b = append(b, s.Bytes()...)
  204. }
  205. return bytes.NewBuffer(b)
  206. }
  207. func (c *Conn) discardBytes(l uint64) {
  208. c.readBytes(l)
  209. }
  210. func (c *Conn) getInt(t int, l uint64) uint64 {
  211. var v uint64
  212. switch t {
  213. case TypeFixedInt:
  214. v = c.decFixedInt(l)
  215. case TypeLenEncInt:
  216. v = c.decLenEncInt()
  217. default:
  218. v = 0
  219. }
  220. return v
  221. }
  222. func (c *Conn) getString(t int, l uint64) string {
  223. var v string
  224. switch t {
  225. case TypeFixedString:
  226. v = c.decFixedString(l)
  227. case TypeLenEncString:
  228. v = string(c.decLenEncInt())
  229. case TypeNullTerminatedString:
  230. v = c.decNullTerminatedString()
  231. case TypeRestOfPacketString:
  232. v = c.decRestOfPacketString()
  233. default:
  234. v = ""
  235. }
  236. return v
  237. }
  238. func (c *Conn) decRestOfPacketString() string {
  239. b := c.getBytesUntilEOF()
  240. return string(b.Bytes())
  241. }
  242. func (c *Conn) decNullTerminatedString() string {
  243. b := c.getBytesUntilNull()
  244. return strings.TrimRight(b.String(), string(NullByte))
  245. }
  246. func (c *Conn) decFixedString(l uint64) string {
  247. b := c.readBytes(l)
  248. return b.String()
  249. }
  250. func (c *Conn) decLenEncInt() uint64 {
  251. var l uint16
  252. b := c.readBytes(1)
  253. br := bytes.NewReader(b.Bytes())
  254. _ = binary.Read(br, binary.LittleEndian, &l)
  255. if l > 0 {
  256. return c.decFixedInt(uint64(l))
  257. } else {
  258. return 0
  259. }
  260. }
  261. func (c *Conn) decFixedInt(l uint64) uint64 {
  262. var i uint64
  263. b := c.readBytes(l)
  264. if l <= 2 {
  265. var x uint16
  266. pb := c.padBytes(2, b.Bytes())
  267. br := bytes.NewReader(pb)
  268. _ = binary.Read(br, binary.LittleEndian, &x)
  269. i = uint64(x)
  270. } else if l <= 4 {
  271. var x uint32
  272. pb := c.padBytes(4, b.Bytes())
  273. br := bytes.NewReader(pb)
  274. _ = binary.Read(br, binary.LittleEndian, &x)
  275. i = uint64(x)
  276. } else if l <= 8 {
  277. var x uint64
  278. pb := c.padBytes(8, b.Bytes())
  279. br := bytes.NewReader(pb)
  280. _ = binary.Read(br, binary.LittleEndian, &x)
  281. i = x
  282. }
  283. return i
  284. }
  285. func (c *Conn) padBytes(l int, b []byte) []byte {
  286. bl := len(b)
  287. pl := l - bl
  288. for i := 0; i < pl; i++ {
  289. b = append(b, NullByte)
  290. }
  291. return b
  292. }
  293. func (c *Conn) encFixedLenInt(v uint64, l uint64) []byte {
  294. b := make([]byte, 8)
  295. binary.LittleEndian.PutUint64(b, v)
  296. return b[:l]
  297. }
  298. func (c *Conn) encLenEncInt(v uint64) []byte {
  299. prefix := make([]byte, 1)
  300. var b []byte
  301. switch {
  302. case v < MaxUint08:
  303. b = make([]byte, 2)
  304. binary.LittleEndian.PutUint16(b, uint16(v))
  305. b = b[:1]
  306. case v >= MaxUint08 && v < MaxUint16:
  307. prefix[0] = 0xFC
  308. b = make([]byte, 3)
  309. binary.LittleEndian.PutUint16(b, uint16(v))
  310. b = b[:2]
  311. case v >= MaxUint16 && v < MaxUint24:
  312. prefix[0] = 0xFD
  313. b = make([]byte, 4)
  314. binary.LittleEndian.PutUint32(b, uint32(v))
  315. b = b[:3]
  316. case v >= MaxUint24 && v < MaxUint64:
  317. prefix[0] = 0xFE
  318. b = make([]byte, 9)
  319. binary.LittleEndian.PutUint64(b, uint64(v))
  320. }
  321. if len(b) > 1 {
  322. b = append(prefix, b...)
  323. }
  324. return b
  325. }
  326. func (c *Conn) bitmaskToStruct(b []byte, s interface{}) interface{} {
  327. l := len(b)
  328. t := reflect.TypeOf(s)
  329. v := reflect.New(t.Elem()).Elem()
  330. for i := uint(0); i < uint(v.NumField()); i++ {
  331. f := v.Field(int(i))
  332. var v bool
  333. switch {
  334. case l > 4:
  335. x := binary.LittleEndian.Uint64(b)
  336. flag := uint64(1 << i)
  337. v = x&flag > 0
  338. case l > 2:
  339. x := binary.LittleEndian.Uint32(b)
  340. flag := uint32(1 << i)
  341. v = x&flag > 0
  342. case l > 1:
  343. x := binary.LittleEndian.Uint16(b)
  344. flag := uint16(1 << i)
  345. v = x&flag > 0
  346. default:
  347. x := uint(b[0])
  348. flag := uint(1 << i)
  349. v = x&flag > 0
  350. }
  351. f.SetBool(v)
  352. }
  353. return v.Interface()
  354. }
  355. func (c *Conn) structToBitmask(s interface{}) []byte {
  356. t := reflect.TypeOf(s).Elem()
  357. sV := reflect.ValueOf(s).Elem()
  358. fC := uint(t.NumField())
  359. m := uint64(0)
  360. for i := uint(0); i < fC; i++ {
  361. f := sV.Field(int(i))
  362. v := f.Bool()
  363. if v {
  364. m |= 1 << i
  365. }
  366. }
  367. l := uint64(math.Ceil(float64(fC) / 8.0))
  368. b := make([]byte, 8)
  369. binary.LittleEndian.PutUint64(b, m)
  370. switch {
  371. case l > 4: // 64 bits
  372. b = b[:8]
  373. case l > 2: // 32 bits
  374. b = b[:4]
  375. case l > 1: // 16 bits
  376. b = b[:2]
  377. default: // 8 bits
  378. b = b[:1]
  379. }
  380. return b
  381. }
  382. func (c *Conn) putString(t int, v string) uint64 {
  383. b := make([]byte, 0)
  384. switch t {
  385. case TypeFixedString:
  386. b = c.encFixedString(v)
  387. case TypeNullTerminatedString:
  388. b = c.encNullTerminatedString(v)
  389. case TypeRestOfPacketString:
  390. b = c.encRestOfPacketString(v)
  391. }
  392. l, err := c.writeBuf.Write(b)
  393. if err != nil {
  394. c.err = err
  395. }
  396. return uint64(l)
  397. }
  398. func (c *Conn) encNullTerminatedString(v string) []byte {
  399. return append([]byte(v), NullByte)
  400. }
  401. func (c *Conn) encFixedString(v string) []byte {
  402. return []byte(v)
  403. }
  404. func (c *Conn) encRestOfPacketString(v string) []byte {
  405. s := c.encFixedString(v)
  406. return s
  407. }
  408. func (c *Conn) putInt(t int, v uint64, l uint64) uint64 {
  409. c.setupWriteBuffer()
  410. b := make([]byte, 0)
  411. switch t {
  412. case TypeFixedInt:
  413. b = c.encFixedLenInt(v, l)
  414. case TypeLenEncInt:
  415. b = c.encLenEncInt(v)
  416. }
  417. n, err := c.writeBuf.Write(b)
  418. if err != nil {
  419. c.err = err
  420. }
  421. return uint64(n)
  422. }
  423. func (c *Conn) putNullBytes(n uint64) uint64 {
  424. c.setupWriteBuffer()
  425. b := make([]byte, n)
  426. l, err := c.writeBuf.Write(b)
  427. if err != nil {
  428. c.err = err
  429. }
  430. return uint64(l)
  431. }
  432. func (c *Conn) putBytes(v []byte) uint64 {
  433. c.setupWriteBuffer()
  434. l, err := c.writeBuf.Write(v)
  435. if err != nil {
  436. c.err = err
  437. }
  438. return uint64(l)
  439. }
  440. func (c *Conn) Flush() error {
  441. if c.err != nil {
  442. return c.err
  443. }
  444. c.writeBuf = c.addHeader()
  445. _, _ = c.buffer.Write(c.writeBuf.Bytes())
  446. if c.buffer.Flush() != nil {
  447. return c.buffer.Flush()
  448. }
  449. c.writeBuf = nil
  450. return nil
  451. }
  452. func (c *Conn) addHeader() *bytes.Buffer {
  453. pl := uint64(c.writeBuf.Len())
  454. sId := uint64(c.sequenceId)
  455. c.sequenceId++
  456. plB := c.encFixedLenInt(pl, 3)
  457. sIdB := c.encFixedLenInt(sId, 1)
  458. return bytes.NewBuffer(append(append(plB, sIdB...), c.writeBuf.Bytes()...))
  459. }
  460. func (c *Conn) setupWriteBuffer() {
  461. if c.writeBuf == nil {
  462. c.writeBuf = bytes.NewBuffer(nil)
  463. }
  464. }
  465. type StatusFlags struct {
  466. }
  467. type OKPacket struct {
  468. PacketHeader
  469. Header uint64
  470. AffectedRows uint64
  471. LastInsertID uint64
  472. StatusFlags uint64
  473. Warnings uint64
  474. Info string
  475. SessionStateInfo string
  476. }
  477. func (c *Conn) decodeOKPacket(ph PacketHeader) (*OKPacket, error) {
  478. op := OKPacket{}
  479. op.PacketHeader = ph
  480. op.Header = ph.Status
  481. op.AffectedRows = c.getInt(TypeLenEncInt, 0)
  482. op.LastInsertID = c.getInt(TypeLenEncInt, 0)
  483. if c.HandshakeResponse.ClientFlag.Protocol41 {
  484. op.StatusFlags = c.getInt(TypeFixedInt, 2)
  485. op.Warnings = c.getInt(TypeFixedInt, 1)
  486. } else if c.HandshakeResponse.ClientFlag.Transactions {
  487. op.StatusFlags = c.getInt(TypeFixedInt, 2)
  488. }
  489. if c.HandshakeResponse.ClientFlag.SessionTrack {
  490. op.Info = c.getString(TypeRestOfPacketString, 0)
  491. } else {
  492. op.Info = c.getString(TypeRestOfPacketString, 0)
  493. }
  494. return &op, nil
  495. }
  496. type ErrorPacket struct {
  497. PacketHeader
  498. ErrorCode uint64
  499. ErrorMessage string
  500. SQLStateMarker string
  501. SQLState string
  502. }
  503. func (c *Conn) decodeErrorPacket(ph PacketHeader) (*ErrorPacket, error) {
  504. ep := ErrorPacket{}
  505. ep.PacketHeader = ph
  506. ep.ErrorCode = c.getInt(TypeFixedInt, 2)
  507. ep.SQLStateMarker = c.getString(TypeFixedString, 1)
  508. ep.SQLState = c.getString(TypeFixedString, 5)
  509. ep.ErrorMessage = c.getString(TypeRestOfPacketString, 0)
  510. err := c.scanner.Err()
  511. if err != nil {
  512. return nil, err
  513. }
  514. return &ep, nil
  515. }
  516. func (c *Conn) setConnection(nc net.Conn) {
  517. c.curConn = nc
  518. c.buffer = bufio.NewReadWriter(
  519. bufio.NewReader(c.curConn),
  520. bufio.NewWriter(c.curConn),
  521. )
  522. c.scanner = bufio.NewScanner(c.buffer.Reader)
  523. c.scanner.Split(bufio.ScanBytes)
  524. }