|
@@ -7,7 +7,7 @@ import (
|
|
|
"fmt"
|
|
"fmt"
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
-type AuthResponse struct {
|
|
|
|
|
|
|
+type AuthResponsePacket struct {
|
|
|
PacketLength uint64
|
|
PacketLength uint64
|
|
|
SequenceID uint64
|
|
SequenceID uint64
|
|
|
Status uint64
|
|
Status uint64
|
|
@@ -15,8 +15,8 @@ type AuthResponse struct {
|
|
|
AuthPluginData *bytes.Buffer
|
|
AuthPluginData *bytes.Buffer
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-func (c *Conn) decodeAuthResponsePacket() (*AuthResponse, error) {
|
|
|
|
|
- packet := AuthResponse{}
|
|
|
|
|
|
|
+func (c *Conn) decodeAuthResponsePacket() (*AuthResponsePacket, error) {
|
|
|
|
|
+ packet := AuthResponsePacket{}
|
|
|
|
|
|
|
|
packet.PacketLength = c.getInt(TypeFixedInt, 3)
|
|
packet.PacketLength = c.getInt(TypeFixedInt, 3)
|
|
|
packet.SequenceID = c.getInt(TypeFixedInt, 1)
|
|
packet.SequenceID = c.getInt(TypeFixedInt, 1)
|
|
@@ -32,23 +32,31 @@ func (c *Conn) decodeAuthResponsePacket() (*AuthResponse, error) {
|
|
|
return &packet, err
|
|
return &packet, err
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-func (c *Conn) writeAuthSwitchPacket() {
|
|
|
|
|
|
|
+func (c *Conn) writeAuthSwitchPacket(ap *AuthResponsePacket) error {
|
|
|
|
|
+ salt := ap.AuthPluginData.Bytes()
|
|
|
|
|
+ password := []byte(c.HandshakeResponse.AuthResponse)
|
|
|
|
|
+ c.authenticate(salt, password)
|
|
|
|
|
|
|
|
|
|
+ if c.Flush() != nil {
|
|
|
|
|
+ return c.Flush()
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ return nil
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-func (c *Conn) authenticate(hr *HandshakeResponse) {
|
|
|
|
|
|
|
+func (c *Conn) authenticate(salt []byte, password []byte) {
|
|
|
var ar []byte
|
|
var ar []byte
|
|
|
- salt := append(c.Handshake.AuthPluginDataPart1.Bytes(), c.Handshake.AuthPluginDataPart2.Bytes()...)
|
|
|
|
|
- password := []byte(hr.AuthResponse)
|
|
|
|
|
- fmt.Println(hr.AuthResponse)
|
|
|
|
|
|
|
|
|
|
|
|
+ salt = salt[:20] // trim null byte from end.
|
|
|
switch c.Handshake.AuthPluginName {
|
|
switch c.Handshake.AuthPluginName {
|
|
|
case "mysql_native_password":
|
|
case "mysql_native_password":
|
|
|
ar = c.nativeSha1Auth(salt, password)
|
|
ar = c.nativeSha1Auth(salt, password)
|
|
|
case "caching_sha2_password":
|
|
case "caching_sha2_password":
|
|
|
|
|
+ fmt.Println(len(salt))
|
|
|
ar = c.cachingSha2Auth(salt, password)
|
|
ar = c.cachingSha2Auth(salt, password)
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+ hr := c.HandshakeResponse
|
|
|
hr.AuthResponseLength = uint64(len(ar))
|
|
hr.AuthResponseLength = uint64(len(ar))
|
|
|
if hr.ClientFlag.PluginAuthLenEncClientData {
|
|
if hr.ClientFlag.PluginAuthLenEncClientData {
|
|
|
c.putInt(TypeLenEncInt, hr.AuthResponseLength, 0)
|
|
c.putInt(TypeLenEncInt, hr.AuthResponseLength, 0)
|