mirror of
https://github.com/XTLS/Xray-core.git
synced 2026-07-04 02:38:42 +00:00
Socks5 server: More standard UDP ASSOCIATE (RFC 1928) (#6149)
https://github.com/XTLS/Xray-core/pull/6149#issuecomment-4529069218 Fixes https://github.com/XTLS/Xray-core/issues/6145#issuecomment-4467482623 --------- Co-authored-by: RPRX <63339210+RPRX@users.noreply.github.com>
This commit is contained in:
+32
-13
@@ -1,14 +1,17 @@
|
|||||||
package socks
|
package socks
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"io"
|
"io"
|
||||||
|
gonet "net"
|
||||||
|
|
||||||
"github.com/xtls/xray-core/common"
|
"github.com/xtls/xray-core/common"
|
||||||
"github.com/xtls/xray-core/common/buf"
|
"github.com/xtls/xray-core/common/buf"
|
||||||
"github.com/xtls/xray-core/common/errors"
|
"github.com/xtls/xray-core/common/errors"
|
||||||
"github.com/xtls/xray-core/common/net"
|
"github.com/xtls/xray-core/common/net"
|
||||||
"github.com/xtls/xray-core/common/protocol"
|
"github.com/xtls/xray-core/common/protocol"
|
||||||
|
"github.com/xtls/xray-core/transport/internet"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -137,13 +140,13 @@ func (s *ServerSession) auth5(nMethod byte, reader io.Reader, writer io.Writer)
|
|||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ServerSession) handshake5(nMethod byte, reader io.Reader, writer io.Writer) (*protocol.RequestHeader, error) {
|
func (s *ServerSession) handshake5(nMethod byte, reader io.Reader, writer net.Conn) (*protocol.RequestHeader, *TempUDPConn, error) {
|
||||||
var (
|
var (
|
||||||
username string
|
username string
|
||||||
err error
|
err error
|
||||||
)
|
)
|
||||||
if username, err = s.auth5(nMethod, reader, writer); err != nil {
|
if username, err = s.auth5(nMethod, reader, writer); err != nil {
|
||||||
return nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var cmd byte
|
var cmd byte
|
||||||
@@ -151,7 +154,7 @@ func (s *ServerSession) handshake5(nMethod byte, reader io.Reader, writer io.Wri
|
|||||||
buffer := buf.StackNew()
|
buffer := buf.StackNew()
|
||||||
if _, err := buffer.ReadFullFrom(reader, 3); err != nil {
|
if _, err := buffer.ReadFullFrom(reader, 3); err != nil {
|
||||||
buffer.Release()
|
buffer.Release()
|
||||||
return nil, errors.New("failed to read request").Base(err)
|
return nil, nil, errors.New("failed to read request").Base(err)
|
||||||
}
|
}
|
||||||
cmd = buffer.Byte(1)
|
cmd = buffer.Byte(1)
|
||||||
buffer.Release()
|
buffer.Release()
|
||||||
@@ -168,28 +171,29 @@ func (s *ServerSession) handshake5(nMethod byte, reader io.Reader, writer io.Wri
|
|||||||
case cmdUDPAssociate:
|
case cmdUDPAssociate:
|
||||||
if !s.config.UdpEnabled {
|
if !s.config.UdpEnabled {
|
||||||
writeSocks5Response(writer, statusCmdNotSupport, net.AnyIP, net.Port(0))
|
writeSocks5Response(writer, statusCmdNotSupport, net.AnyIP, net.Port(0))
|
||||||
return nil, errors.New("UDP is not enabled.")
|
return nil, nil, errors.New("UDP is not enabled.")
|
||||||
}
|
}
|
||||||
request.Command = protocol.RequestCommandUDP
|
request.Command = protocol.RequestCommandUDP
|
||||||
case cmdTCPBind:
|
case cmdTCPBind:
|
||||||
writeSocks5Response(writer, statusCmdNotSupport, net.AnyIP, net.Port(0))
|
writeSocks5Response(writer, statusCmdNotSupport, net.AnyIP, net.Port(0))
|
||||||
return nil, errors.New("TCP bind is not supported.")
|
return nil, nil, errors.New("TCP bind is not supported.")
|
||||||
default:
|
default:
|
||||||
writeSocks5Response(writer, statusCmdNotSupport, net.AnyIP, net.Port(0))
|
writeSocks5Response(writer, statusCmdNotSupport, net.AnyIP, net.Port(0))
|
||||||
return nil, errors.New("unknown command ", cmd)
|
return nil, nil, errors.New("unknown command ", cmd)
|
||||||
}
|
}
|
||||||
|
|
||||||
request.Version = socks5Version
|
request.Version = socks5Version
|
||||||
|
|
||||||
addr, port, err := addrParser.ReadAddressPort(nil, reader)
|
addr, port, err := addrParser.ReadAddressPort(nil, reader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.New("failed to read address").Base(err)
|
return nil, nil, errors.New("failed to read address").Base(err)
|
||||||
}
|
}
|
||||||
request.Address = addr
|
request.Address = addr
|
||||||
request.Port = port
|
request.Port = port
|
||||||
|
|
||||||
responseAddress := s.address
|
responseAddress := s.address
|
||||||
responsePort := s.port
|
responsePort := s.port
|
||||||
|
var tempUDPConn *TempUDPConn
|
||||||
//nolint:gocritic // Use if else chain for clarity
|
//nolint:gocritic // Use if else chain for clarity
|
||||||
if request.Command == protocol.RequestCommandUDP {
|
if request.Command == protocol.RequestCommandUDP {
|
||||||
if s.config.Address != nil {
|
if s.config.Address != nil {
|
||||||
@@ -199,20 +203,34 @@ func (s *ServerSession) handshake5(nMethod byte, reader io.Reader, writer io.Wri
|
|||||||
// Use conn.LocalAddr() IP as remote address in the response by default
|
// Use conn.LocalAddr() IP as remote address in the response by default
|
||||||
responseAddress = s.localAddress
|
responseAddress = s.localAddress
|
||||||
}
|
}
|
||||||
|
udpHub, err := internet.ListenSystemPacket(context.Background(), &net.UDPAddr{IP: responseAddress.IP(), Port: 0}, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, errors.New("failed to create UDP listener").Base(err)
|
||||||
|
}
|
||||||
|
responsePort = net.Port(udpHub.LocalAddr().(*net.UDPAddr).Port)
|
||||||
|
expectedRemote := &gonet.UDPAddr{}
|
||||||
|
if request.Address.IP().IsUnspecified() {
|
||||||
|
expectedRemote.IP = writer.RemoteAddr().(*net.TCPAddr).IP // unix?
|
||||||
|
} else {
|
||||||
|
expectedRemote.IP = request.Address.IP() // panic?
|
||||||
|
expectedRemote.Port = int(request.Port) // 0 is allowed
|
||||||
|
}
|
||||||
|
tempUDPConn = NewTempUDPConn(udpHub, writer, expectedRemote)
|
||||||
}
|
}
|
||||||
if err := writeSocks5Response(writer, statusSuccess, responseAddress, responsePort); err != nil {
|
if err := writeSocks5Response(writer, statusSuccess, responseAddress, responsePort); err != nil {
|
||||||
return nil, err
|
common.CloseIfExists(tempUDPConn)
|
||||||
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return request, nil
|
return request, tempUDPConn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handshake performs a Socks4/4a/5 handshake.
|
// Handshake performs a Socks4/4a/5 handshake.
|
||||||
func (s *ServerSession) Handshake(reader io.Reader, writer io.Writer) (*protocol.RequestHeader, error) {
|
func (s *ServerSession) Handshake(reader io.Reader, writer net.Conn) (*protocol.RequestHeader, *TempUDPConn, error) {
|
||||||
buffer := buf.StackNew()
|
buffer := buf.StackNew()
|
||||||
if _, err := buffer.ReadFullFrom(reader, 2); err != nil {
|
if _, err := buffer.ReadFullFrom(reader, 2); err != nil {
|
||||||
buffer.Release()
|
buffer.Release()
|
||||||
return nil, errors.New("insufficient header").Base(err)
|
return nil, nil, errors.New("insufficient header").Base(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
version := buffer.Byte(0)
|
version := buffer.Byte(0)
|
||||||
@@ -221,11 +239,12 @@ func (s *ServerSession) Handshake(reader io.Reader, writer io.Writer) (*protocol
|
|||||||
|
|
||||||
switch version {
|
switch version {
|
||||||
case socks4Version:
|
case socks4Version:
|
||||||
return s.handshake4(cmd, reader, writer)
|
header, err := s.handshake4(cmd, reader, writer)
|
||||||
|
return header, nil, err
|
||||||
case socks5Version:
|
case socks5Version:
|
||||||
return s.handshake5(cmd, reader, writer)
|
return s.handshake5(cmd, reader, writer)
|
||||||
default:
|
default:
|
||||||
return nil, errors.New("unknown Socks version: ", version)
|
return nil, nil, errors.New("unknown Socks version: ", version)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
+16
-24
@@ -29,7 +29,6 @@ type Server struct {
|
|||||||
config *ServerConfig
|
config *ServerConfig
|
||||||
policyManager policy.Manager
|
policyManager policy.Manager
|
||||||
cone bool
|
cone bool
|
||||||
udpFilter *UDPFilter
|
|
||||||
httpServer *http.Server
|
httpServer *http.Server
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -46,7 +45,6 @@ func NewServer(ctx context.Context, config *ServerConfig) (*Server, error) {
|
|||||||
}
|
}
|
||||||
if config.AuthType == AuthType_PASSWORD {
|
if config.AuthType == AuthType_PASSWORD {
|
||||||
httpConfig.Accounts = config.Accounts
|
httpConfig.Accounts = config.Accounts
|
||||||
s.udpFilter = new(UDPFilter) // We only use this when auth is enabled
|
|
||||||
}
|
}
|
||||||
s.httpServer, _ = http.NewServer(ctx, httpConfig)
|
s.httpServer, _ = http.NewServer(ctx, httpConfig)
|
||||||
return s, nil
|
return s, nil
|
||||||
@@ -60,11 +58,7 @@ func (s *Server) policy() policy.Session {
|
|||||||
|
|
||||||
// Network implements proxy.Inbound.
|
// Network implements proxy.Inbound.
|
||||||
func (s *Server) Network() []net.Network {
|
func (s *Server) Network() []net.Network {
|
||||||
list := []net.Network{net.Network_TCP}
|
return []net.Network{net.Network_TCP}
|
||||||
if s.config.UdpEnabled {
|
|
||||||
list = append(list, net.Network_UDP)
|
|
||||||
}
|
|
||||||
return list
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process implements proxy.Inbound.
|
// Process implements proxy.Inbound.
|
||||||
@@ -94,8 +88,6 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con
|
|||||||
return s.httpServer.ProcessWithFirstbyte(ctx, network, conn, dispatcher, firstbyte...)
|
return s.httpServer.ProcessWithFirstbyte(ctx, network, conn, dispatcher, firstbyte...)
|
||||||
}
|
}
|
||||||
return s.processTCP(ctx, conn, dispatcher, firstbyte)
|
return s.processTCP(ctx, conn, dispatcher, firstbyte)
|
||||||
case net.Network_UDP:
|
|
||||||
return s.handleUDPPayload(ctx, conn, dispatcher)
|
|
||||||
default:
|
default:
|
||||||
return errors.New("unknown network: ", network)
|
return errors.New("unknown network: ", network)
|
||||||
}
|
}
|
||||||
@@ -126,7 +118,8 @@ func (s *Server) processTCP(ctx context.Context, conn stat.Connection, dispatche
|
|||||||
Reader: buf.NewReader(conn),
|
Reader: buf.NewReader(conn),
|
||||||
Buffer: buf.MultiBuffer{buf.FromBytes(firstbyte)},
|
Buffer: buf.MultiBuffer{buf.FromBytes(firstbyte)},
|
||||||
}
|
}
|
||||||
request, err := svrSession.Handshake(reader, conn)
|
request, tempUDPConn, err := svrSession.Handshake(reader, conn)
|
||||||
|
defer common.CloseIfExists(tempUDPConn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if inbound.Source.IsValid() {
|
if inbound.Source.IsValid() {
|
||||||
log.Record(&log.AccessMessage{
|
log.Record(&log.AccessMessage{
|
||||||
@@ -170,26 +163,25 @@ func (s *Server) processTCP(ctx context.Context, conn stat.Connection, dispatche
|
|||||||
}
|
}
|
||||||
|
|
||||||
if request.Command == protocol.RequestCommandUDP {
|
if request.Command == protocol.RequestCommandUDP {
|
||||||
if s.udpFilter != nil {
|
if tempUDPConn == nil {
|
||||||
s.udpFilter.Add(conn.RemoteAddr())
|
return errors.New("UDP associate with listen port failed")
|
||||||
}
|
}
|
||||||
return s.handleUDP(conn)
|
tempUDPConn.SetTimeout(plcy.Timeouts.ConnectionIdle)
|
||||||
|
errCh := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
errCh <- s.handleUDPPayload(ctx, tempUDPConn, dispatcher)
|
||||||
|
}()
|
||||||
|
// Associated TCP keeps the UDP alive
|
||||||
|
// Close UDP if TCP connection is closed
|
||||||
|
// Or Close TCP if UDP is idle timeout
|
||||||
|
io.Copy(buf.DiscardBytes, conn)
|
||||||
|
tempUDPConn.Close()
|
||||||
|
return <-errCh
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*Server) handleUDP(c io.Reader) error {
|
|
||||||
// The TCP connection closes after this method returns. We need to wait until
|
|
||||||
// the client closes it.
|
|
||||||
return common.Error2(io.Copy(buf.DiscardBytes, c))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dispatcher routing.Dispatcher) error {
|
func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dispatcher routing.Dispatcher) error {
|
||||||
if s.udpFilter != nil && !s.udpFilter.Check(conn.RemoteAddr()) {
|
|
||||||
errors.LogDebug(ctx, "Unauthorized UDP access from ", conn.RemoteAddr().String())
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
udpServer := udp.NewDispatcher(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) {
|
udpServer := udp.NewDispatcher(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) {
|
||||||
payload := packet.Payload
|
payload := packet.Payload
|
||||||
errors.LogDebug(ctx, "writing back UDP response with ", payload.Len(), " bytes")
|
errors.LogDebug(ctx, "writing back UDP response with ", payload.Len(), " bytes")
|
||||||
|
|||||||
@@ -0,0 +1,72 @@
|
|||||||
|
package socks
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/xtls/xray-core/common/signal"
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewTempUDPConn(udpConn net.PacketConn, tcpConn net.Conn, expectedRemote *net.UDPAddr) *TempUDPConn {
|
||||||
|
t := &TempUDPConn{
|
||||||
|
PacketConn: udpConn,
|
||||||
|
AssociatedTCPConn: tcpConn,
|
||||||
|
}
|
||||||
|
t.ExpectedRemote.Store(expectedRemote)
|
||||||
|
return t
|
||||||
|
}
|
||||||
|
|
||||||
|
// TempUDPConn wait for the first packet to determine the remote address
|
||||||
|
// SetTimeout MUST be called before any read/write operation
|
||||||
|
type TempUDPConn struct {
|
||||||
|
net.PacketConn
|
||||||
|
AssociatedTCPConn net.Conn
|
||||||
|
ExpectedRemote atomic.Pointer[net.UDPAddr]
|
||||||
|
Timer *signal.ActivityTimer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *TempUDPConn) Read(b []byte) (n int, err error) {
|
||||||
|
var remote net.Addr
|
||||||
|
for {
|
||||||
|
n, remote, err = c.PacketConn.ReadFrom(b)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
remote := remote.(*net.UDPAddr)
|
||||||
|
expected := c.ExpectedRemote.Load()
|
||||||
|
if remote.IP.Equal(expected.IP) {
|
||||||
|
if remote.Port == expected.Port {
|
||||||
|
c.Timer.Update()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if expected.Port == 0 {
|
||||||
|
c.ExpectedRemote.Store(remote)
|
||||||
|
c.Timer.Update()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *TempUDPConn) Write(b []byte) (n int, err error) {
|
||||||
|
c.Timer.Update()
|
||||||
|
return c.PacketConn.WriteTo(b, c.ExpectedRemote.Load())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *TempUDPConn) RemoteAddr() net.Addr {
|
||||||
|
return c.ExpectedRemote.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *TempUDPConn) SetTimeout(d time.Duration) {
|
||||||
|
c.Timer = signal.CancelAfterInactivity(context.Background(), func() {
|
||||||
|
c.Close()
|
||||||
|
}, d)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *TempUDPConn) Close() error {
|
||||||
|
c.Timer.SetTimeout(0)
|
||||||
|
c.AssociatedTCPConn.Close()
|
||||||
|
return c.PacketConn.Close()
|
||||||
|
}
|
||||||
@@ -1,31 +0,0 @@
|
|||||||
package socks
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
"sync"
|
|
||||||
)
|
|
||||||
|
|
||||||
/*
|
|
||||||
In the sock implementation of * ray, UDP authentication is flawed and can be bypassed.
|
|
||||||
Tracking a UDP connection may be a bit troublesome.
|
|
||||||
Here is a simple solution.
|
|
||||||
We create a filter, add remote IP to the pool when it try to establish a UDP connection with auth.
|
|
||||||
And drop UDP packets from unauthorized IP.
|
|
||||||
After discussion, we believe it is not necessary to add a timeout mechanism to this filter.
|
|
||||||
*/
|
|
||||||
|
|
||||||
type UDPFilter struct {
|
|
||||||
ips sync.Map
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *UDPFilter) Add(addr net.Addr) bool {
|
|
||||||
ip, _, _ := net.SplitHostPort(addr.String())
|
|
||||||
f.ips.Store(ip, true)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *UDPFilter) Check(addr net.Addr) bool {
|
|
||||||
ip, _, _ := net.SplitHostPort(addr.String())
|
|
||||||
_, ok := f.ips.Load(ip)
|
|
||||||
return ok
|
|
||||||
}
|
|
||||||
Reference in New Issue
Block a user