authentication.go 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. package binlog
  2. import (
  3. "bytes"
  4. "crypto/sha1"
  5. "crypto/sha256"
  6. "fmt"
  7. )
  // AuthResponse holds the fields decoded from an authentication response
  // packet by decodeAuthResponsePacket, in the order they are read off the wire.
  type AuthResponse struct {
  	// PacketLength (3 bytes), SequenceID (1 byte) and Status (1 byte) are each
  	// decoded as fixed-width integers.
  	PacketLength, SequenceID, Status uint64
  	// PluginName is the NUL-terminated plugin name that follows Status.
  	PluginName string
  	// AuthPluginData holds the plugin data bytes read after PluginName.
  	AuthPluginData *bytes.Buffer
  }
  15. func (c *Conn) decodeAuthResponsePacket() (*AuthResponse, error) {
  16. packet := AuthResponse{}
  17. packet.PacketLength = c.getInt(TypeFixedInt, 3)
  18. packet.SequenceID = c.getInt(TypeFixedInt, 1)
  19. packet.Status = c.getInt(TypeFixedInt, 1)
  20. packet.PluginName = c.getString(TypeNullTerminatedString, 0)
  21. packet.AuthPluginData = c.readBytes(20)
  22. err := c.scanner.Err()
  23. if err != nil {
  24. return nil, err
  25. }
  26. return &packet, err
  27. }
  28. func (c *Conn) writeAuthSwitchPacket() {
  29. }
  30. func (c *Conn) authenticate(hr *HandshakeResponse) {
  31. var ar []byte
  32. salt := append(c.Handshake.AuthPluginDataPart1.Bytes(), c.Handshake.AuthPluginDataPart2.Bytes()...)
  33. password := []byte(hr.AuthResponse)
  34. fmt.Println(hr.AuthResponse)
  35. switch c.Handshake.AuthPluginName {
  36. case "mysql_native_password":
  37. ar = c.nativeSha1Auth(salt, password)
  38. case "caching_sha2_password":
  39. ar = c.cachingSha2Auth(salt, password)
  40. }
  41. hr.AuthResponseLength = uint64(len(ar))
  42. if hr.ClientFlag.PluginAuthLenEncClientData {
  43. c.putInt(TypeLenEncInt, hr.AuthResponseLength, 0)
  44. c.putBytes(ar)
  45. } else if hr.ClientFlag.SecureConnection {
  46. c.putInt(TypeFixedInt, hr.AuthResponseLength, 1)
  47. c.putBytes(ar)
  48. } else {
  49. c.putString(TypeNullTerminatedString, string(ar))
  50. }
  51. }
  52. func (c *Conn) nativeSha1Auth(salt []byte, password []byte) []byte {
  53. pHash := c.sha1Hash(password)
  54. pHashHash := c.sha1Hash(pHash)
  55. spHash := c.sha1Hash(append(salt, pHashHash...))
  56. for i := range pHash {
  57. pHash[i] ^= spHash[i]
  58. }
  59. return pHash
  60. }
  61. func (c *Conn) cachingSha2Auth(salt []byte, password []byte) []byte {
  62. if len(password) < 1 {
  63. return nil
  64. }
  65. pHash := c.sha256Hash(password)
  66. pHashHash := c.sha256Hash(pHash)
  67. pHashHashHash := c.sha256Hash(pHashHash)
  68. authData := c.sha256Hash(append(pHashHashHash, salt...))
  69. for i := range pHash {
  70. pHash[i] ^= authData[i]
  71. }
  72. return pHash
  73. }
  74. func (c *Conn) sha1Hash(word []byte) []byte {
  75. s := sha1.New()
  76. s.Write(word)
  77. return s.Sum(nil)
  78. }
  79. func (c *Conn) sha256Hash(word []byte) []byte {
  80. s := sha256.New()
  81. s.Write(word)
  82. return s.Sum(nil)
  83. }