diff --git a/proxy/wireguard/client.go b/proxy/wireguard/client.go index 2e5fc286..560d7424 100644 --- a/proxy/wireguard/client.go +++ b/proxy/wireguard/client.go @@ -114,6 +114,8 @@ func (h *Handler) processWireGuard(ctx context.Context, dialer internet.Dialer) } // bind := conn.NewStdNetBind() // TODO: conn.Bind wrapper for dialer + // Use a detached context for the bind to avoid tying all peer connections + // to a single request context. This allows multiple peers to work independently. h.bind = &netBindClient{ netBind: netBind{ dns: h.dns, @@ -123,7 +125,7 @@ func (h *Handler) processWireGuard(ctx context.Context, dialer internet.Dialer) }, workers: int(h.conf.NumWorkers), }, - ctx: ctx, + ctx: core.ToBackgroundDetachedContext(ctx), dialer: dialer, reserved: h.conf.Reserved, } diff --git a/proxy/wireguard/server.go b/proxy/wireguard/server.go index 989abd54..6144f5c7 100644 --- a/proxy/wireguard/server.go +++ b/proxy/wireguard/server.go @@ -4,7 +4,6 @@ import ( "context" goerrors "errors" "io" - "sync" "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/buf" @@ -27,10 +26,6 @@ var nullDestination = net.TCPDestination(net.AnyIP, 0) type Server struct { bindServer *netBindServer - // Use a mutex-protected default routing info for forwarded connections - // Since we cannot determine which peer initiated a forwarded connection from gvisor, - // we use the most recently set routing info as default - infoMutex sync.RWMutex info routingInfo policyManager policy.Manager } @@ -83,25 +78,11 @@ func (*Server) Network() []net.Network { // Process implements proxy.Inbound. func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error { - // Use double-checked locking to safely handle concurrent access to routing info - // Only update if not set or if dispatcher is different - // First check without write lock for better concurrency - s.infoMutex.RLock() - needsUpdate := s.info.dispatcher == nil || s.info.dispatcher != dispatcher - s.infoMutex.RUnlock() - - if needsUpdate { - s.infoMutex.Lock() - // Double-check after acquiring write lock - if s.info.dispatcher == nil || s.info.dispatcher != dispatcher { - s.info = routingInfo{ - ctx: ctx, - dispatcher: dispatcher, - inboundTag: session.InboundFromContext(ctx), - contentTag: session.ContentFromContext(ctx), - } - } - s.infoMutex.Unlock() + s.info = routingInfo{ + ctx: ctx, + dispatcher: dispatcher, + inboundTag: session.InboundFromContext(ctx), + contentTag: session.ContentFromContext(ctx), } ep, err := s.bindServer.ParseEndpoint(conn.RemoteAddr().String()) @@ -139,23 +120,18 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con } func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) { - // Safely read routing info - s.infoMutex.RLock() - info := s.info - s.infoMutex.RUnlock() - - if info.dispatcher == nil { - errors.LogError(info.ctx, "unexpected: dispatcher == nil") + if s.info.dispatcher == nil { + errors.LogError(s.info.ctx, "unexpected: dispatcher == nil") return } defer conn.Close() - ctx, cancel := context.WithCancel(core.ToBackgroundDetachedContext(info.ctx)) + ctx, cancel := context.WithCancel(core.ToBackgroundDetachedContext(s.info.ctx)) sid := session.NewID() ctx = c.ContextWithID(ctx, sid) inbound := session.Inbound{} // since promiscuousModeHandler mixed-up context, we shallow copy inbound (tag) and content (configs) - if info.inboundTag != nil { - inbound = *info.inboundTag + if s.info.inboundTag != nil { + inbound = *s.info.inboundTag } inbound.Name = "wireguard" inbound.CanSpliceCopy = 3 @@ -165,8 +141,8 @@ func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) { // Currently we have no way to link to the original source address inbound.Source = net.DestinationFromAddr(conn.RemoteAddr()) ctx = session.ContextWithInbound(ctx, &inbound) - if info.contentTag != nil { - ctx = session.ContextWithContent(ctx, info.contentTag) + if s.info.contentTag != nil { + ctx = session.ContextWithContent(ctx, s.info.contentTag) } ctx = session.SubContextFromMuxInbound(ctx) @@ -180,7 +156,7 @@ func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) { Reason: "", }) - link, err := info.dispatcher.Dispatch(ctx, dest) + link, err := s.info.dispatcher.Dispatch(ctx, dest) if err != nil { errors.LogErrorInner(ctx, err, "dispatch connection") }