handshake.go 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278
  1. package binlog
  2. import (
  3. "bytes"
  4. "crypto/tls"
  5. "crypto/x509"
  6. "fmt"
  7. "io/ioutil"
  8. )
  // Capabilities is the decoded, one-bool-per-flag view of a MySQL
// capability bitmask. decodeCapabilityFlags fills it from the server
// handshake, and structToBitmask turns it back into bytes when the client
// writes its SSLRequest or HandshakeResponse.
//
// NOTE(review): bitmaskToStruct and structToBitmask are defined elsewhere and
// presumably map fields to bits by declaration order, so the order below must
// not change — TODO confirm. If the mapping is positional, SSLVerifyServerCert,
// OptionalResultSetMetadata and RememberOptions land on bits 25-27. MySQL
// defines CLIENT_OPTIONAL_RESULTSET_METADATA as 1<<25,
// CLIENT_SSL_VERIFY_SERVER_CERT as 1<<30 and CLIENT_REMEMBER_OPTIONS as 1<<31
// — verify.
type Capabilities struct {
	LongPassword, FoundRows, LongFlag, ConnectWithDB                        bool
	NoSchema, Compress, ODBC, LocalFiles                                    bool
	IgnoreSpace, Protocol41, Interactive, SSL                               bool
	IgnoreSigpipe, Transactions, LegacyProtocol41, SecureConnection         bool
	MultiStatements, MultiResults, PSMultiResults, PluginAuth               bool
	ConnectAttrs, PluginAuthLenEncClientData, CanHandleExpiredPasswords, SessionTrack bool
	DeprecateEOF, SSLVerifyServerCert, OptionalResultSetMetadata, RememberOptions     bool
}
  // Status is the decoded, one-bool-per-flag view of the server status
// bitmask carried in the handshake (Handshake.StatusFlags); decodeStatusFlags
// produces it.
//
// NOTE(review): bitmaskToStruct (defined elsewhere) presumably maps fields to
// bits by declaration order, so do not reorder — TODO confirm. If it does,
// note that MySQL leaves status value 4 unused and defines
// SERVER_MORE_RESULTS_EXISTS as 8. That would shift MoreResultsExists and
// every later field by one bit — verify.
type Status struct {
	InTrans, Autocommit, MoreResultsExists, QueryNoGoodIndexUsed bool
	QueryNoIndexUsed, CursorExists, LastRowSent, DBDropped       bool
	NoBackslashEscapes, MetadataChanged, QueryWasSlow, PSOutParams bool
	InTransReadonly, SessionStateChanged                         bool
}
  55. type Handshake struct {
  56. PacketLength uint64
  57. SequenceID uint64
  58. ProtocolVersion uint64
  59. ServerVersion string
  60. ThreadID uint64
  61. AuthPluginDataPart1 *bytes.Buffer
  62. CapabilityFlags1 *bytes.Buffer
  63. Charset uint64
  64. StatusFlags *bytes.Buffer
  65. CapabilityFlags2 *bytes.Buffer
  66. AuthPluginDataLength uint64
  67. AuthPluginDataPart2 *bytes.Buffer
  68. AuthPluginName string
  69. Capabilities *Capabilities
  70. Status *Status
  71. }
  72. type HandshakeResponse struct {
  73. ClientFlag *Capabilities
  74. MaxPacketSize uint64
  75. CharacterSet uint64
  76. Username string
  77. AuthResponseLength uint64
  78. AuthResponse string
  79. Database string
  80. ClientPluginName string
  81. KeyValues map[string]string
  82. }
  83. type SSLRequest struct {
  84. ClientFlag *Capabilities
  85. MaxPacketSize uint64
  86. CharacterSet uint64
  87. Username string
  88. }
  89. func (c *Conn) decodeCapabilityFlags(hs *Handshake) {
  90. var cfb = append(hs.CapabilityFlags1.Bytes(), hs.CapabilityFlags2.Bytes()...)
  91. capabilities := c.bitmaskToStruct(cfb, hs.Capabilities).(Capabilities)
  92. hs.Capabilities = &capabilities
  93. }
  94. func (c *Conn) decodeStatusFlags(hs *Handshake) {
  95. status := c.bitmaskToStruct(hs.StatusFlags.Bytes(), hs.Status).(Status)
  96. hs.Status = &status
  97. }
  98. func (c *Conn) decodeHandshakePacket() error {
  99. packet := Handshake{}
  100. packet.PacketLength = c.getInt(TypeFixedInt, 3)
  101. packet.SequenceID = c.getInt(TypeFixedInt, 1)
  102. packet.ProtocolVersion = c.getInt(TypeFixedInt, 1)
  103. packet.ServerVersion = c.getString(TypeNullTerminatedString, 0)
  104. packet.ThreadID = c.getInt(TypeFixedInt, 4)
  105. packet.AuthPluginDataPart1 = c.readBytes(8)
  106. c.discardBytes(1)
  107. packet.CapabilityFlags1 = c.readBytes(2)
  108. packet.Charset = c.getInt(TypeFixedInt, 1)
  109. packet.StatusFlags = c.readBytes(2)
  110. c.decodeStatusFlags(&packet)
  111. packet.CapabilityFlags2 = c.readBytes(2)
  112. c.decodeCapabilityFlags(&packet)
  113. packet.AuthPluginDataLength = c.getInt(TypeFixedInt, 1)
  114. c.discardBytes(10)
  115. p1l := uint64(packet.AuthPluginDataPart1.Len())
  116. packet.AuthPluginDataPart2 = c.readBytes(packet.AuthPluginDataLength - p1l)
  117. packet.AuthPluginName = c.getString(TypeNullTerminatedString, 0)
  118. err := c.scanner.Err()
  119. if err != nil {
  120. return err
  121. }
  122. c.Handshake = &packet
  123. return nil
  124. }
  125. func (c *Conn) writeHandshakeResponse() error {
  126. hr := c.HandshakeResponse
  127. cf := c.structToBitmask(hr.ClientFlag)
  128. c.putBytes(cf)
  129. c.putInt(TypeFixedInt, hr.MaxPacketSize, 4)
  130. c.putInt(TypeFixedInt, hr.CharacterSet, 1)
  131. c.putNullBytes(23)
  132. c.putString(TypeNullTerminatedString, hr.Username)
  133. // Perform authentication
  134. salt := append(c.Handshake.AuthPluginDataPart1.Bytes(), c.Handshake.AuthPluginDataPart2.Bytes()...)
  135. password := []byte(hr.AuthResponse)
  136. c.authenticate(salt, password)
  137. // Write database name
  138. if hr.ClientFlag.ConnectWithDB {
  139. c.putString(TypeNullTerminatedString, hr.Database)
  140. }
  141. // Set type of auth plugin based on if it is at the end of the packet.
  142. var t int
  143. if hr.KeyValues != nil {
  144. t = TypeNullTerminatedString
  145. } else {
  146. t = TypeRestOfPacketString
  147. }
  148. // Write auth plugin
  149. if hr.ClientFlag.PluginAuth {
  150. c.putString(t, hr.ClientPluginName)
  151. c.putNullBytes(1)
  152. }
  153. if c.Flush() != nil {
  154. return c.Flush()
  155. }
  156. return nil
  157. }
  158. func (c *Conn) writeSSLRequestPacket() error {
  159. sr := c.NewSSLRequest()
  160. cf := c.structToBitmask(sr.ClientFlag)
  161. c.putBytes(cf)
  162. c.putInt(TypeFixedInt, sr.MaxPacketSize, 4)
  163. c.putInt(TypeFixedInt, sr.CharacterSet, 1)
  164. c.putNullBytes(23)
  165. if c.Flush() != nil {
  166. return c.Flush()
  167. }
  168. return nil
  169. }
  170. func (c *Conn) NewSSLRequest() *SSLRequest {
  171. return &SSLRequest{
  172. ClientFlag: c.HandshakeResponse.ClientFlag,
  173. MaxPacketSize: c.HandshakeResponse.MaxPacketSize,
  174. CharacterSet: c.HandshakeResponse.CharacterSet,
  175. Username: c.HandshakeResponse.Username,
  176. }
  177. }
  178. func (c *Conn) NewHandshakeResponse() *HandshakeResponse {
  179. return &HandshakeResponse{
  180. ClientFlag: &Capabilities{
  181. LongPassword: true,
  182. FoundRows: true,
  183. LongFlag: false,
  184. ConnectWithDB: false,
  185. NoSchema: false,
  186. Compress: false,
  187. ODBC: false,
  188. LocalFiles: false,
  189. IgnoreSpace: true,
  190. Protocol41: true,
  191. Interactive: true,
  192. SSL: c.Config.SSL,
  193. IgnoreSigpipe: false,
  194. Transactions: true,
  195. LegacyProtocol41: false,
  196. SecureConnection: true,
  197. MultiStatements: false,
  198. MultiResults: false,
  199. PSMultiResults: true,
  200. PluginAuth: true,
  201. ConnectAttrs: false,
  202. PluginAuthLenEncClientData: false,
  203. CanHandleExpiredPasswords: false,
  204. SessionTrack: true,
  205. DeprecateEOF: false,
  206. SSLVerifyServerCert: c.Config.VerifyCert,
  207. OptionalResultSetMetadata: false,
  208. RememberOptions: false,
  209. },
  210. MaxPacketSize: MaxPacketSize,
  211. CharacterSet: 45,
  212. Username: c.Config.User,
  213. AuthResponseLength: 0,
  214. AuthResponse: c.Config.Pass,
  215. Database: c.Config.Database,
  216. ClientPluginName: c.Handshake.AuthPluginName,
  217. KeyValues: nil,
  218. }
  219. }
  // NewClientTLSConfig builds the client-side TLS configuration.
//
// Parameters:
//   - keyPem, cerPem: paths to a PEM private key and certificate. When both
//     are non-empty, they are loaded with tls.LoadX509KeyPair and presented
//     to the server as the client certificate.
//   - caPem: path to a PEM file of CA certificates used to verify the
//     server. Because the parameter is a []byte, inline PEM data is also
//     accepted when it cannot be read as a path. When empty, RootCAs stays
//     nil and crypto/tls falls back to the host's root CA set.
//   - insecureSkipVerify: when true, neither the server's certificate chain
//     nor serverName is verified.
//   - serverName: expected host name of the server certificate.
//
// It panics if the CA data cannot be read or parsed, or if the key pair
// cannot be loaded.
func NewClientTLSConfig(keyPem string, cerPem string, caPem []byte, insecureSkipVerify bool, serverName string) *tls.Config {
	config := &tls.Config{
		// Honour the caller's flag as-is. Negating it would disable
		// verification exactly when the caller asked for it.
		InsecureSkipVerify: insecureSkipVerify,
		ServerName:         serverName,
	}

	if len(caPem) > 0 {
		ca, err := ioutil.ReadFile(string(caPem))
		if err != nil {
			// Never silently drop a CA the caller explicitly supplied.
			if !bytes.Contains(caPem, []byte("-----BEGIN")) {
				panic(fmt.Sprintf("read ca PEM file %q: %v", caPem, err))
			}
			ca = caPem
		}
		pool := x509.NewCertPool()
		if !pool.AppendCertsFromPEM(ca) {
			panic("failed to add ca PEM")
		}
		config.RootCAs = pool
	}

	if keyPem != "" && cerPem != "" {
		cert, err := tls.LoadX509KeyPair(cerPem, keyPem)
		if err != nil {
			panic(err)
		}
		config.Certificates = []tls.Certificate{cert}
	}
	return config
}