diff --git a/client/proxy/proxy.go b/client/proxy/proxy.go index 3b7d183..c68fe70 100644 --- a/client/proxy/proxy.go +++ b/client/proxy/proxy.go @@ -523,6 +523,9 @@ func HandleTcpWorkConnection(localInfo *config.LocalSvrConf, proxyPlugin plugin. // check if we need to send proxy protocol info if baseInfo.ProxyProtocolVersion != "" { if m.SrcAddr != "" && m.SrcPort != 0 { + if m.DstAddr == "" { + m.DstAddr = "127.0.0.1" + } h := &pp.Header{ Command: pp.PROXY, SourceAddress: net.ParseIP(m.SrcAddr), diff --git a/server/proxy/http.go b/server/proxy/http.go index e609222..c5bc1ac 100644 --- a/server/proxy/http.go +++ b/server/proxy/http.go @@ -16,6 +16,7 @@ package proxy import ( "io" + "net" "strings" "github.com/fatedier/frp/g" @@ -97,8 +98,14 @@ func (pxy *HttpProxy) GetConf() config.ProxyConf { return pxy.cfg } -func (pxy *HttpProxy) GetRealConn() (workConn frpNet.Conn, err error) { - tmpConn, errRet := pxy.GetWorkConnFromPool(nil, nil) +func (pxy *HttpProxy) GetRealConn(remoteAddr string) (workConn frpNet.Conn, err error) { + rAddr, errRet := net.ResolveTCPAddr("tcp", remoteAddr) + if errRet != nil { + pxy.Warn("resolve TCP addr [%s] error: %v", remoteAddr, errRet) + // we do not return error here since remoteAddr is not necessary for proxies without proxy protocol enabled + } + + tmpConn, errRet := pxy.GetWorkConnFromPool(rAddr, nil) if errRet != nil { err = errRet return diff --git a/utils/vhost/newhttp.go b/utils/vhost/http.go similarity index 95% rename from utils/vhost/newhttp.go rename to utils/vhost/http.go index 59a4a0e..7bbc361 100644 --- a/utils/vhost/newhttp.go +++ b/utils/vhost/http.go @@ -89,13 +89,15 @@ func NewHttpReverseProxy(option HttpReverseProxyOptions) *HttpReverseProxy { DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { url := ctx.Value("url").(string) host := getHostFromAddr(ctx.Value("host").(string)) - return rp.CreateConnection(host, url) + remote := ctx.Value("remote").(string) + return rp.CreateConnection(host, url, remote) }, }, WebSocketDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { url := ctx.Value("url").(string) host := getHostFromAddr(ctx.Value("host").(string)) - return rp.CreateConnection(host, url) + remote := ctx.Value("remote").(string) + return rp.CreateConnection(host, url, remote) }, BufferPool: newWrapPool(), ErrorLog: log.New(newWrapLogger(), "", 0), @@ -138,12 +140,12 @@ func (rp *HttpReverseProxy) GetHeaders(domain string, location string) (headers return } -func (rp *HttpReverseProxy) CreateConnection(domain string, location string) (net.Conn, error) { +func (rp *HttpReverseProxy) CreateConnection(domain string, location string, remoteAddr string) (net.Conn, error) { vr, ok := rp.getVhost(domain, location) if ok { fn := vr.payload.(*VhostRouteConfig).CreateConnFn if fn != nil { - return fn() + return fn(remoteAddr) } } return nil, fmt.Errorf("%v: %s %s", ErrNoDomain, domain, location) diff --git a/utils/vhost/reverseproxy.go b/utils/vhost/reverseproxy.go index 6b0c292..5662589 100644 --- a/utils/vhost/reverseproxy.go +++ b/utils/vhost/reverseproxy.go @@ -158,6 +158,7 @@ func (p *ReverseProxy) serveWebSocket(rw http.ResponseWriter, req *http.Request) req = req.WithContext(context.WithValue(req.Context(), "url", req.URL.Path)) req = req.WithContext(context.WithValue(req.Context(), "host", req.Host)) + req = req.WithContext(context.WithValue(req.Context(), "remote", req.RemoteAddr)) targetConn, err := p.WebSocketDialContext(req.Context(), "tcp", "") if err != nil { @@ -215,6 +216,7 @@ func (p *ReverseProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) { // Modify for frp outreq = outreq.WithContext(context.WithValue(outreq.Context(), "url", req.URL.Path)) outreq = outreq.WithContext(context.WithValue(outreq.Context(), "host", req.Host)) + outreq = outreq.WithContext(context.WithValue(outreq.Context(), "remote", req.RemoteAddr)) p.Director(outreq) outreq.Close = false diff --git a/utils/vhost/vhost.go b/utils/vhost/vhost.go index 2e38652..f366e1e 100644 --- a/utils/vhost/vhost.go +++ b/utils/vhost/vhost.go @@ -51,7 +51,7 @@ func NewVhostMuxer(listener frpNet.Listener, vhostFunc muxFunc, authFunc httpAut return mux, nil } -type CreateConnFunc func() (frpNet.Conn, error) +type CreateConnFunc func(remoteAddr string) (frpNet.Conn, error) type VhostRouteConfig struct { Domain string