handshake.go 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276
  1. package binlog
  2. import (
  3. "bytes"
  4. "crypto/tls"
  5. "crypto/x509"
  6. "io/ioutil"
  7. )
  // Capabilities is a boolean view of the capability bitmask exchanged during
// the handshake. decodeCapabilityFlags fills it from the server's four
// capability bytes (CapabilityFlags1 followed by CapabilityFlags2) via
// bitmaskToStruct. writeHandshakeResponse and writeSSLRequestPacket serialize
// the client's copy via structToBitmask.
//
// The fields LongPassword through DeprecateEOF appear in the same order as
// MySQL's CLIENT_* capability bits 0-24.
//
// NOTE(review): bitmaskToStruct/structToBitmask are not visible here. If they
// map field i to bit i, then field order is load-bearing, and the last three
// fields are misplaced. MySQL defines:
//   - CLIENT_OPTIONAL_RESULTSET_METADATA as bit 25,
//   - CLIENT_SSL_VERIFY_SERVER_CERT as bit 30,
//   - CLIENT_REMEMBER_OPTIONS as bit 31.
// Here SSLVerifyServerCert, OptionalResultSetMetadata and RememberOptions would
// land on bits 25, 26 and 27 respectively. Confirm before reordering, adding
// or removing fields.
type Capabilities struct {
	LongPassword               bool
	FoundRows                  bool
	LongFlag                   bool
	ConnectWithDB              bool
	NoSchema                   bool
	Compress                   bool
	ODBC                       bool
	LocalFiles                 bool
	IgnoreSpace                bool
	Protocol41                 bool
	Interactive                bool
	SSL                        bool
	IgnoreSigpipe              bool
	Transactions               bool
	LegacyProtocol41           bool
	SecureConnection           bool
	MultiStatements            bool
	MultiResults               bool
	PSMultiResults             bool
	PluginAuth                 bool
	ConnectAttrs               bool
	PluginAuthLenEncClientData bool
	CanHandleExpiredPasswords  bool
	SessionTrack               bool
	DeprecateEOF               bool
	SSLVerifyServerCert        bool
	OptionalResultSetMetadata  bool
	RememberOptions            bool
}
  // Status is a boolean view of the 2-byte server status bitmask
// (Handshake.StatusFlags). decodeStatusFlags fills it via bitmaskToStruct.
//
// NOTE(review): if bitmaskToStruct maps field i to bit i (not visible here,
// so confirm), the mapping is off by one from MoreResultsExists onward. In
// MySQL's SERVER_STATUS flags, bit 2 (value 4) is unused and
// SERVER_MORE_RESULTS_EXISTS is bit 3 (value 8). This struct has no
// placeholder field for bit 2, so MoreResultsExists and every later field
// would read the bit below the one MySQL defines.
type Status struct {
	InTrans              bool
	Autocommit           bool
	MoreResultsExists    bool
	QueryNoGoodIndexUsed bool
	QueryNoIndexUsed     bool
	CursorExists         bool
	LastRowSent          bool
	DBDropped            bool
	NoBackslashEscapes   bool
	MetadataChanged      bool
	QueryWasSlow         bool
	PSOutParams          bool
	InTransReadonly      bool
	SessionStateChanged  bool
}
  54. type Handshake struct {
  55. PacketLength uint64
  56. SequenceID uint64
  57. ProtocolVersion uint64
  58. ServerVersion string
  59. ThreadID uint64
  60. AuthPluginDataPart1 *bytes.Buffer
  61. CapabilityFlags1 *bytes.Buffer
  62. Charset uint64
  63. StatusFlags *bytes.Buffer
  64. CapabilityFlags2 *bytes.Buffer
  65. AuthPluginDataLength uint64
  66. AuthPluginDataPart2 *bytes.Buffer
  67. AuthPluginName string
  68. Capabilities *Capabilities
  69. Status *Status
  70. }
  71. type HandshakeResponse struct {
  72. ClientFlag *Capabilities
  73. MaxPacketSize uint64
  74. CharacterSet uint64
  75. Username string
  76. AuthResponseLength uint64
  77. AuthResponse string
  78. Database string
  79. ClientPluginName string
  80. KeyValues map[string]string
  81. }
  82. type SSLRequest struct {
  83. ClientFlag *Capabilities
  84. MaxPacketSize uint64
  85. CharacterSet uint64
  86. Username string
  87. }
  88. func (c *Conn) decodeCapabilityFlags(hs *Handshake) {
  89. var cfb = append(hs.CapabilityFlags1.Bytes(), hs.CapabilityFlags2.Bytes()...)
  90. capabilities := c.bitmaskToStruct(cfb, hs.Capabilities).(Capabilities)
  91. hs.Capabilities = &capabilities
  92. }
  93. func (c *Conn) decodeStatusFlags(hs *Handshake) {
  94. status := c.bitmaskToStruct(hs.StatusFlags.Bytes(), hs.Status).(Status)
  95. hs.Status = &status
  96. }
  97. func (c *Conn) decodeHandshakePacket() error {
  98. packet := Handshake{}
  99. packet.PacketLength = c.getInt(TypeFixedInt, 3)
  100. packet.SequenceID = c.getInt(TypeFixedInt, 1)
  101. packet.ProtocolVersion = c.getInt(TypeFixedInt, 1)
  102. packet.ServerVersion = c.getString(TypeNullTerminatedString, 0)
  103. packet.ThreadID = c.getInt(TypeFixedInt, 4)
  104. packet.AuthPluginDataPart1 = c.readBytes(8)
  105. c.discardBytes(1)
  106. packet.CapabilityFlags1 = c.readBytes(2)
  107. packet.Charset = c.getInt(TypeFixedInt, 1)
  108. packet.StatusFlags = c.readBytes(2)
  109. c.decodeStatusFlags(&packet)
  110. packet.CapabilityFlags2 = c.readBytes(2)
  111. c.decodeCapabilityFlags(&packet)
  112. packet.AuthPluginDataLength = c.getInt(TypeFixedInt, 1)
  113. c.discardBytes(10)
  114. p1l := uint64(packet.AuthPluginDataPart1.Len())
  115. packet.AuthPluginDataPart2 = c.readBytes(packet.AuthPluginDataLength - p1l)
  116. packet.AuthPluginName = c.getString(TypeNullTerminatedString, 0)
  117. err := c.scanner.Err()
  118. if err != nil {
  119. return err
  120. }
  121. c.Handshake = &packet
  122. return nil
  123. }
  124. func (c *Conn) writeHandshakeResponse() error {
  125. hr := c.HandshakeResponse
  126. cf := c.structToBitmask(hr.ClientFlag)
  127. c.putBytes(cf)
  128. c.putInt(TypeFixedInt, hr.MaxPacketSize, 4)
  129. c.putInt(TypeFixedInt, hr.CharacterSet, 1)
  130. c.putNullBytes(23)
  131. c.putString(TypeNullTerminatedString, hr.Username)
  132. // Perform authentication
  133. salt := append(c.Handshake.AuthPluginDataPart1.Bytes(), c.Handshake.AuthPluginDataPart2.Bytes()...)
  134. password := []byte(hr.AuthResponse)
  135. c.authenticate(salt, password)
  136. // Write database name
  137. if hr.ClientFlag.ConnectWithDB {
  138. c.putString(TypeNullTerminatedString, hr.Database)
  139. }
  140. // Set type of auth plugin based on if it is at the end of the packet.
  141. var t int
  142. if hr.KeyValues != nil {
  143. t = TypeNullTerminatedString
  144. } else {
  145. t = TypeRestOfPacketString
  146. }
  147. // Write auth plugin
  148. if hr.ClientFlag.PluginAuth {
  149. c.putString(t, hr.ClientPluginName)
  150. c.putNullBytes(1)
  151. }
  152. if c.Flush() != nil {
  153. return c.Flush()
  154. }
  155. return nil
  156. }
  157. func (c *Conn) writeSSLRequestPacket() error {
  158. sr := c.NewSSLRequest()
  159. cf := c.structToBitmask(sr.ClientFlag)
  160. c.putBytes(cf)
  161. c.putInt(TypeFixedInt, sr.MaxPacketSize, 4)
  162. c.putInt(TypeFixedInt, sr.CharacterSet, 1)
  163. c.putNullBytes(23)
  164. if c.Flush() != nil {
  165. return c.Flush()
  166. }
  167. return nil
  168. }
  169. func (c *Conn) NewSSLRequest() *SSLRequest {
  170. return &SSLRequest{
  171. ClientFlag: c.HandshakeResponse.ClientFlag,
  172. MaxPacketSize: c.HandshakeResponse.MaxPacketSize,
  173. CharacterSet: c.HandshakeResponse.CharacterSet,
  174. Username: c.HandshakeResponse.Username,
  175. }
  176. }
  177. func (c *Conn) NewHandshakeResponse() *HandshakeResponse {
  178. return &HandshakeResponse{
  179. ClientFlag: &Capabilities{
  180. LongPassword: true,
  181. FoundRows: true,
  182. LongFlag: false,
  183. ConnectWithDB: false,
  184. NoSchema: false,
  185. Compress: false,
  186. ODBC: false,
  187. LocalFiles: false,
  188. IgnoreSpace: true,
  189. Protocol41: true,
  190. Interactive: true,
  191. SSL: c.Config.SSL,
  192. IgnoreSigpipe: false,
  193. Transactions: true,
  194. LegacyProtocol41: false,
  195. SecureConnection: true,
  196. MultiStatements: false,
  197. MultiResults: false,
  198. PSMultiResults: true,
  199. PluginAuth: true,
  200. ConnectAttrs: false,
  201. PluginAuthLenEncClientData: false,
  202. CanHandleExpiredPasswords: false,
  203. SessionTrack: true,
  204. DeprecateEOF: false,
  205. SSLVerifyServerCert: c.Config.VerifyCert,
  206. OptionalResultSetMetadata: false,
  207. RememberOptions: false,
  208. },
  209. MaxPacketSize: MaxPacketSize,
  210. CharacterSet: 45,
  211. Username: c.Config.User,
  212. AuthResponseLength: 0,
  213. AuthResponse: c.Config.Pass,
  214. Database: c.Config.Database,
  215. ClientPluginName: c.Handshake.AuthPluginName,
  216. KeyValues: nil,
  217. }
  218. }
  // NewClientTLSConfig builds a client-side *tls.Config.
//
// Parameters:
//   - keyPem, cerPem: file paths of the client private key and certificate,
//     loaded with tls.LoadX509KeyPair. A client certificate is configured only
//     when both are non-empty.
//   - caPem: file path (as bytes) of a PEM bundle used as RootCAs. Empty or
//     nil means the system roots are used.
//   - insecureSkipVerify: the value is NEGATED before being stored in
//     tls.Config.InsecureSkipVerify. Passing true keeps server certificate
//     verification on, and passing false disables it, so it behaves like a
//     "verify server cert" switch despite its name.
//     NOTE(review): the name suggests the opposite meaning. Audit callers
//     before renaming or flipping it; flipping it silently would disable
//     verification for callers that pass a "verify" setting here.
//   - serverName: sent for SNI and, when verification is on, checked against
//     the server certificate's hostname.
//
// It panics if the CA file cannot be read or yields no PEM certificates, or
// if the key pair cannot be loaded.
func NewClientTLSConfig(keyPem string, cerPem string, caPem []byte, insecureSkipVerify bool, serverName string) *tls.Config {
	config := &tls.Config{
		InsecureSkipVerify: !insecureSkipVerify,
		ServerName:         serverName,
	}
	if len(caPem) > 0 {
		ca, err := ioutil.ReadFile(string(caPem))
		if err != nil {
			// Fail loudly, like the other configuration errors below. An
			// explicitly configured CA that silently fell back to the system
			// roots would be hard to diagnose.
			panic("failed to read ca PEM file: " + err.Error())
		}
		pool := x509.NewCertPool()
		if !pool.AppendCertsFromPEM(ca) {
			panic("failed to add ca PEM")
		}
		config.RootCAs = pool
	}
	if keyPem != "" && cerPem != "" {
		cert, err := tls.LoadX509KeyPair(cerPem, keyPem)
		if err != nil {
			panic(err)
		}
		config.Certificates = []tls.Certificate{cert}
	}
	return config
}