vhost: use new readClientHello function - From Upstream #2504

This commit is contained in:
夜棂依 2024-07-30 08:39:33 +08:00 committed by GitHub
parent 4d00e51df7
commit 6a0b1789da
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -15,33 +15,14 @@
package vhost
import (
"fmt"
"crypto/tls"
"io"
"strings"
"net"
"time"
frpNet "github.com/fatedier/frp/utils/net"
gnet "github.com/fatedier/golib/net"
"github.com/fatedier/golib/pool"
)
const (
typeClientHello uint8 = 1 // Type client hello
)
// TLS extension numbers
const (
extensionServerName uint16 = 0
extensionStatusRequest uint16 = 5
extensionSupportedCurves uint16 = 10
extensionSupportedPoints uint16 = 11
extensionSignatureAlgorithms uint16 = 13
extensionALPN uint16 = 16
extensionSCT uint16 = 18
extensionSessionTicket uint16 = 35
extensionNextProtoNeg uint16 = 13172 // not IANA assigned
extensionRenegotiationInfo uint16 = 0xff01
)
type HttpsMuxer struct {
@ -53,143 +34,49 @@ func NewHttpsMuxer(listener frpNet.Listener, timeout time.Duration) (*HttpsMuxer
return &HttpsMuxer{mux}, err
}
func readHandshake(rd io.Reader) (host string, err error) {
data := pool.GetBuf(1024)
origin := data
defer pool.PutBuf(origin)
_, err = io.ReadFull(rd, data[:47])
if err != nil {
return
}
length, err := rd.Read(data[47:])
if err != nil {
return
} else {
length += 47
}
data = data[:length]
if uint8(data[5]) != typeClientHello {
err = fmt.Errorf("readHandshake: type[%d] is not clientHello", uint16(data[5]))
return
}
// session
sessionIdLen := int(data[43])
if sessionIdLen > 32 || len(data) < 44+sessionIdLen {
err = fmt.Errorf("readHandshake: sessionIdLen[%d] is long", sessionIdLen)
return
}
data = data[44+sessionIdLen:]
if len(data) < 2 {
err = fmt.Errorf("readHandshake: dataLen[%d] after session is short", len(data))
return
}
// cipher suite numbers
cipherSuiteLen := int(data[0])<<8 | int(data[1])
if cipherSuiteLen%2 == 1 || len(data) < 2+cipherSuiteLen {
err = fmt.Errorf("readHandshake: dataLen[%d] after cipher suite is short", len(data))
return
}
data = data[2+cipherSuiteLen:]
if len(data) < 1 {
err = fmt.Errorf("readHandshake: cipherSuiteLen[%d] is long", cipherSuiteLen)
return
}
// compression method
compressionMethodsLen := int(data[0])
if len(data) < 1+compressionMethodsLen {
err = fmt.Errorf("readHandshake: compressionMethodsLen[%d] is long", compressionMethodsLen)
return
}
data = data[1+compressionMethodsLen:]
if len(data) == 0 {
// ClientHello is optionally followed by extension data
err = fmt.Errorf("readHandshake: there is no extension data to get servername")
return
}
if len(data) < 2 {
err = fmt.Errorf("readHandshake: extension dataLen[%d] is too short", len(data))
return
}
extensionsLength := int(data[0])<<8 | int(data[1])
data = data[2:]
if extensionsLength != len(data) {
err = fmt.Errorf("readHandshake: extensionsLen[%d] is not equal to dataLen[%d]", extensionsLength, len(data))
return
}
for len(data) != 0 {
if len(data) < 4 {
err = fmt.Errorf("readHandshake: extensionsDataLen[%d] is too short", len(data))
return
}
extension := uint16(data[0])<<8 | uint16(data[1])
length := int(data[2])<<8 | int(data[3])
data = data[4:]
if len(data) < length {
err = fmt.Errorf("readHandshake: extensionLen[%d] is long", length)
return
}
switch extension {
case extensionRenegotiationInfo:
if length != 1 || data[0] != 0 {
err = fmt.Errorf("readHandshake: extension reNegotiationInfoLen[%d] is short", length)
return
}
case extensionNextProtoNeg:
case extensionStatusRequest:
case extensionServerName:
d := data[:length]
if len(d) < 2 {
err = fmt.Errorf("readHandshake: remiaining dataLen[%d] is short", len(d))
return
}
namesLen := int(d[0])<<8 | int(d[1])
d = d[2:]
if len(d) != namesLen {
err = fmt.Errorf("readHandshake: nameListLen[%d] is not equal to dataLen[%d]", namesLen, len(d))
return
}
for len(d) > 0 {
if len(d) < 3 {
err = fmt.Errorf("readHandshake: extension serverNameLen[%d] is short", len(d))
return
}
nameType := d[0]
nameLen := int(d[1])<<8 | int(d[2])
d = d[3:]
if len(d) < nameLen {
err = fmt.Errorf("readHandshake: nameLen[%d] is not equal to dataLen[%d]", nameLen, len(d))
return
}
if nameType == 0 {
serverName := string(d[:nameLen])
host = strings.TrimSpace(serverName)
return host, nil
}
d = d[nameLen:]
}
}
data = data[length:]
}
err = fmt.Errorf("Unknow error")
return
}
func GetHttpsHostname(c frpNet.Conn) (_ frpNet.Conn, _ map[string]string, err error) {
reqInfoMap := make(map[string]string, 0)
sc, rd := gnet.NewSharedConn(c)
host, err := readHandshake(rd)
clientHello, err := readClientHello(rd)
if err != nil {
return nil, reqInfoMap, err
}
reqInfoMap["Host"] = host
reqInfoMap["Host"] = clientHello.ServerName
reqInfoMap["Scheme"] = "https"
return frpNet.WrapConn(sc), reqInfoMap, nil
}
func readClientHello(reader io.Reader) (*tls.ClientHelloInfo, error) {
var hello *tls.ClientHelloInfo
// Note that Handshake always fails because the readOnlyConn is not a real connection.
// As long as the Client Hello is successfully read, the failure should only happen after GetConfigForClient is called,
// so we only care about the error if hello was never set.
err := tls.Server(readOnlyConn{reader: reader}, &tls.Config{
GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) {
hello = &tls.ClientHelloInfo{}
*hello = *argHello
return nil, nil
},
}).Handshake()
if hello == nil {
return nil, err
}
return hello, nil
}
type readOnlyConn struct {
reader io.Reader
}
func (conn readOnlyConn) Read(p []byte) (int, error) { return conn.reader.Read(p) }
func (conn readOnlyConn) Write(p []byte) (int, error) { return 0, io.ErrClosedPipe }
func (conn readOnlyConn) Close() error { return nil }
func (conn readOnlyConn) LocalAddr() net.Addr { return nil }
func (conn readOnlyConn) RemoteAddr() net.Addr { return nil }
func (conn readOnlyConn) SetDeadline(t time.Time) error { return nil }
func (conn readOnlyConn) SetReadDeadline(t time.Time) error { return nil }
func (conn readOnlyConn) SetWriteDeadline(t time.Time) error { return nil }