handshake.go 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282
  1. package binlog
  2. import (
  3. "bytes"
  4. "crypto/tls"
  5. "crypto/x509"
  6. "io/ioutil"
  7. )
  // Capabilities is a decoded MySQL capability bitmask with one bool per flag.
  // It is used both for the flags the server advertises in its Handshake
  // (Handshake.Capabilities) and for the flags the client requests in reply
  // (HandshakeResponse.ClientFlag, SSLRequest.ClientFlag).
  //
  // Values are converted from and to raw bytes by the Conn methods
  // bitmaskToStruct and structToBitmask. Keep the field order unchanged: those
  // helpers are presumably order-sensitive.
  //
  // NOTE(review): if bitmaskToStruct/structToBitmask assign bits by field
  // position, this order must follow MySQL's CLIENT_* numbering. It matches from
  // LongPassword (1<<0) through DeprecateEOF (1<<24). However, the MySQL
  // documentation lists CLIENT_OPTIONAL_RESULTSET_METADATA as 1<<25,
  // CLIENT_SSL_VERIFY_SERVER_CERT as 1<<30 and CLIENT_REMEMBER_OPTIONS as 1<<31,
  // so the last three fields would map to the wrong bits — confirm.
  type Capabilities struct {
  	LongPassword               bool
  	FoundRows                  bool
  	LongFlag                   bool
  	ConnectWithDB              bool
  	NoSchema                   bool
  	Compress                   bool
  	ODBC                       bool
  	LocalFiles                 bool
  	IgnoreSpace                bool
  	Protocol41                 bool
  	Interactive                bool
  	SSL                        bool
  	IgnoreSigpipe              bool
  	Transactions               bool
  	LegacyProtocol41           bool
  	SecureConnection           bool
  	MultiStatements            bool
  	MultiResults               bool
  	PSMultiResults             bool
  	PluginAuth                 bool
  	ConnectAttrs               bool
  	PluginAuthLenEncClientData bool
  	CanHandleExpiredPasswords  bool
  	SessionTrack               bool
  	DeprecateEOF               bool
  	SSLVerifyServerCert        bool
  	OptionalResultSetMetadata  bool
  	RememberOptions            bool
  }
  // Status is a decoded MySQL server-status bitmask with one bool per flag. It
  // is produced from Handshake.StatusFlags by decodeStatusFlags, which delegates
  // to the Conn method bitmaskToStruct. Keep the field order unchanged: that
  // helper is presumably order-sensitive.
  //
  // NOTE(review): if bitmaskToStruct assigns bits by field position, note that
  // MySQL defines SERVER_STATUS_IN_TRANS = 1, SERVER_STATUS_AUTOCOMMIT = 2 and
  // SERVER_MORE_RESULTS_EXISTS = 8, leaving bit value 4 unused. A purely
  // positional mapping would shift MoreResultsExists and every later field by
  // one bit — confirm.
  type Status struct {
  	InTrans              bool
  	Autocommit           bool
  	MoreResultsExists    bool
  	QueryNoGoodIndexUsed bool
  	QueryNoIndexUsed     bool
  	CursorExists         bool
  	LastRowSent          bool
  	DBDropped            bool
  	NoBackslashEscapes   bool
  	MetadataChanged      bool
  	QueryWasSlow         bool
  	PSOutParams          bool
  	InTransReadonly      bool
  	SessionStateChanged  bool
  }
  56. // Handshake represents the Handshake packet from the MySQL protocol.
  57. type Handshake struct {
  58. PacketLength uint64
  59. SequenceID uint64
  60. ProtocolVersion uint64
  61. ServerVersion string
  62. ThreadID uint64
  63. AuthPluginDataPart1 *bytes.Buffer
  64. CapabilityFlags1 *bytes.Buffer
  65. Charset uint64
  66. StatusFlags *bytes.Buffer
  67. CapabilityFlags2 *bytes.Buffer
  68. AuthPluginDataLength uint64
  69. AuthPluginDataPart2 *bytes.Buffer
  70. AuthPluginName string
  71. Capabilities *Capabilities
  72. Status *Status
  73. }
  74. // HandshakeResponse represents the Handshake Response packet from the MySQL protocol.
  75. type HandshakeResponse struct {
  76. ClientFlag *Capabilities
  77. MaxPacketSize uint64
  78. CharacterSet uint64
  79. Username string
  80. AuthResponseLength uint64
  81. AuthResponse string
  82. Database string
  83. ClientPluginName string
  84. KeyValues map[string]string
  85. }
  86. // SSLRequest represents the SSL Request packet from the MySQL protocol.
  87. type SSLRequest struct {
  88. ClientFlag *Capabilities
  89. MaxPacketSize uint64
  90. CharacterSet uint64
  91. Username string
  92. }
  93. func (c *Conn) decodeCapabilityFlags(hs *Handshake) {
  94. var cfb = append(hs.CapabilityFlags1.Bytes(), hs.CapabilityFlags2.Bytes()...)
  95. capabilities := c.bitmaskToStruct(cfb, hs.Capabilities).(Capabilities)
  96. hs.Capabilities = &capabilities
  97. }
  98. func (c *Conn) decodeStatusFlags(hs *Handshake) {
  99. status := c.bitmaskToStruct(hs.StatusFlags.Bytes(), hs.Status).(Status)
  100. hs.Status = &status
  101. }
  102. func (c *Conn) decodeHandshakePacket() error {
  103. packet := Handshake{}
  104. packet.PacketLength = c.getInt(TypeFixedInt, 3)
  105. packet.SequenceID = c.getInt(TypeFixedInt, 1)
  106. packet.ProtocolVersion = c.getInt(TypeFixedInt, 1)
  107. packet.ServerVersion = c.getString(TypeNullTerminatedString, 0)
  108. packet.ThreadID = c.getInt(TypeFixedInt, 4)
  109. packet.AuthPluginDataPart1 = c.readBytes(8)
  110. c.discardBytes(1)
  111. packet.CapabilityFlags1 = c.readBytes(2)
  112. packet.Charset = c.getInt(TypeFixedInt, 1)
  113. packet.StatusFlags = c.readBytes(2)
  114. c.decodeStatusFlags(&packet)
  115. packet.CapabilityFlags2 = c.readBytes(2)
  116. c.decodeCapabilityFlags(&packet)
  117. packet.AuthPluginDataLength = c.getInt(TypeFixedInt, 1)
  118. c.discardBytes(10)
  119. p1l := uint64(packet.AuthPluginDataPart1.Len())
  120. packet.AuthPluginDataPart2 = c.readBytes(packet.AuthPluginDataLength - p1l)
  121. packet.AuthPluginName = c.getString(TypeNullTerminatedString, 0)
  122. err := c.scanner.Err()
  123. if err != nil {
  124. return err
  125. }
  126. c.Handshake = &packet
  127. return nil
  128. }
  129. func (c *Conn) writeHandshakeResponse() error {
  130. hr := c.HandshakeResponse
  131. cf := c.structToBitmask(hr.ClientFlag)
  132. c.putBytes(cf)
  133. c.putInt(TypeFixedInt, hr.MaxPacketSize, 4)
  134. c.putInt(TypeFixedInt, hr.CharacterSet, 1)
  135. c.putNullBytes(23)
  136. c.putString(TypeNullTerminatedString, hr.Username)
  137. // Perform authentication
  138. salt := append(c.Handshake.AuthPluginDataPart1.Bytes(), c.Handshake.AuthPluginDataPart2.Bytes()...)
  139. password := []byte(hr.AuthResponse)
  140. c.authenticate(salt, password)
  141. // Write database name
  142. if hr.ClientFlag.ConnectWithDB {
  143. c.putString(TypeNullTerminatedString, hr.Database)
  144. }
  145. // Set type of auth plugin based on if it is at the end of the packet.
  146. var t int
  147. if hr.KeyValues != nil {
  148. t = TypeNullTerminatedString
  149. } else {
  150. t = TypeRestOfPacketString
  151. }
  152. // Write auth plugin
  153. if hr.ClientFlag.PluginAuth {
  154. c.putString(t, hr.ClientPluginName)
  155. c.putNullBytes(1)
  156. }
  157. if c.Flush() != nil {
  158. return c.Flush()
  159. }
  160. return nil
  161. }
  162. func (c *Conn) writeSSLRequestPacket() error {
  163. sr := c.NewSSLRequest()
  164. cf := c.structToBitmask(sr.ClientFlag)
  165. c.putBytes(cf)
  166. c.putInt(TypeFixedInt, sr.MaxPacketSize, 4)
  167. c.putInt(TypeFixedInt, sr.CharacterSet, 1)
  168. c.putNullBytes(23)
  169. if c.Flush() != nil {
  170. return c.Flush()
  171. }
  172. return nil
  173. }
  174. // NewSSLRequest creates a new SSL Request packet using information from the handshake response.
  175. func (c *Conn) NewSSLRequest() *SSLRequest {
  176. return &SSLRequest{
  177. ClientFlag: c.HandshakeResponse.ClientFlag,
  178. MaxPacketSize: c.HandshakeResponse.MaxPacketSize,
  179. CharacterSet: c.HandshakeResponse.CharacterSet,
  180. Username: c.HandshakeResponse.Username,
  181. }
  182. }
  183. // NewHandshakeResponse creates the default handshake response packet.
  184. func (c *Conn) NewHandshakeResponse() *HandshakeResponse {
  185. return &HandshakeResponse{
  186. ClientFlag: &Capabilities{
  187. LongPassword: true,
  188. FoundRows: true,
  189. LongFlag: false,
  190. ConnectWithDB: false,
  191. NoSchema: false,
  192. Compress: false,
  193. ODBC: false,
  194. LocalFiles: false,
  195. IgnoreSpace: true,
  196. Protocol41: true,
  197. Interactive: true,
  198. SSL: c.Config.SSL,
  199. IgnoreSigpipe: false,
  200. Transactions: c.Handshake.Capabilities.Transactions,
  201. LegacyProtocol41: false,
  202. SecureConnection: true,
  203. MultiStatements: false,
  204. MultiResults: false,
  205. PSMultiResults: true,
  206. PluginAuth: c.Handshake.Capabilities.PluginAuth,
  207. ConnectAttrs: false,
  208. PluginAuthLenEncClientData: false,
  209. CanHandleExpiredPasswords: false,
  210. SessionTrack: c.Handshake.Capabilities.SessionTrack,
  211. DeprecateEOF: false,
  212. SSLVerifyServerCert: c.Config.VerifyCert,
  213. OptionalResultSetMetadata: false,
  214. RememberOptions: false,
  215. },
  216. MaxPacketSize: MaxPacketSize,
  217. CharacterSet: 45,
  218. Username: c.Config.User,
  219. AuthResponseLength: 0,
  220. AuthResponse: c.Config.Pass,
  221. Database: c.Config.Database,
  222. ClientPluginName: c.Handshake.AuthPluginName,
  223. KeyValues: nil,
  224. }
  225. }
  // NewClientTLSConfig builds a client-side *tls.Config.
  //
  // Parameters:
  //   - keyPem, cerPem: file paths of the client private key and certificate,
  //     passed to tls.LoadX509KeyPair. The pair is loaded only when both are
  //     non-empty.
  //   - caPem: file path (as bytes) of a PEM CA bundle. It is converted with
  //     string(caPem) and read from disk; it is not PEM content itself. It is
  //     ignored when empty. If the file cannot be read it is skipped silently
  //     and RootCAs stays nil, so crypto/tls falls back to the host's root set.
  //   - insecureSkipVerify: despite its name, true ENABLES server certificate
  //     verification. false sets tls.Config.InsecureSkipVerify, so any server
  //     certificate and host name are accepted.
  //   - serverName: host name the server certificate is verified against.
  //
  // It panics if the CA file is readable but contains no parsable certificate,
  // or if the client key pair cannot be loaded.
  //
  // NOTE(review): the parameter name is the inverse of its effect. Flipping the
  // code would silently disable verification for any caller that passes a
  // "verify" flag here, so audit every caller before renaming or inverting it.
  func NewClientTLSConfig(keyPem string, cerPem string, caPem []byte, insecureSkipVerify bool, serverName string) *tls.Config {
  	config := &tls.Config{
  		// Inverted on purpose to preserve existing behavior; see NOTE above.
  		InsecureSkipVerify: !insecureSkipVerify,
  		ServerName:         serverName,
  	}

  	if len(caPem) > 0 {
  		// Best effort: an unreadable CA file leaves RootCAs nil.
  		if ca, err := ioutil.ReadFile(string(caPem)); err == nil {
  			pool := x509.NewCertPool()
  			if !pool.AppendCertsFromPEM(ca) {
  				panic("failed to add ca PEM")
  			}
  			config.RootCAs = pool
  		}
  	}

  	if keyPem != "" && cerPem != "" {
  		cert, err := tls.LoadX509KeyPair(cerPem, keyPem)
  		if err != nil {
  			panic(err)
  		}
  		config.Certificates = []tls.Certificate{cert}
  	}
  	return config
  }