Fix race condition in WireGuard server with concurrent peer connections

Add mutex protection to server.go to prevent race condition when multiple
peers connect simultaneously. The shared routingInfo field was being
overwritten by concurrent Process() calls, causing connections to fail.

- Add sync.RWMutex to protect access to routing info
- Only update routing info if not already set or dispatcher changed
- Use local copy of routing info in forwardConnection to avoid races
- Existing tests pass

Co-authored-by: RPRX <63339210+RPRX@users.noreply.github.com>
This commit is contained in:
copilot-swe-agent[bot]
2026-01-09 10:28:10 +00:00
parent a99fe66467
commit 385867e82b
+29 -13
View File
@@ -4,6 +4,7 @@ import (
"context" "context"
goerrors "errors" goerrors "errors"
"io" "io"
"sync"
"github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/buf"
@@ -26,6 +27,10 @@ var nullDestination = net.TCPDestination(net.AnyIP, 0)
type Server struct { type Server struct {
bindServer *netBindServer bindServer *netBindServer
// Use a mutex-protected default routing info for forwarded connections
// Since we cannot determine which peer initiated a forwarded connection from gvisor,
// we use the most recently set routing info as default
infoMutex sync.RWMutex
info routingInfo info routingInfo
policyManager policy.Manager policyManager policy.Manager
} }
@@ -78,12 +83,18 @@ func (*Server) Network() []net.Network {
// Process implements proxy.Inbound. // Process implements proxy.Inbound.
func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error { func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error {
s.info = routingInfo{ // Use RWMutex to safely handle concurrent access to routing info
ctx: ctx, // Only update if not set or if dispatcher is different
dispatcher: dispatcher, s.infoMutex.Lock()
inboundTag: session.InboundFromContext(ctx), if s.info.dispatcher == nil || s.info.dispatcher != dispatcher {
contentTag: session.ContentFromContext(ctx), s.info = routingInfo{
ctx: ctx,
dispatcher: dispatcher,
inboundTag: session.InboundFromContext(ctx),
contentTag: session.ContentFromContext(ctx),
}
} }
s.infoMutex.Unlock()
ep, err := s.bindServer.ParseEndpoint(conn.RemoteAddr().String()) ep, err := s.bindServer.ParseEndpoint(conn.RemoteAddr().String())
if err != nil { if err != nil {
@@ -120,18 +131,23 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con
} }
func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) { func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) {
if s.info.dispatcher == nil { // Safely read routing info
errors.LogError(s.info.ctx, "unexpected: dispatcher == nil") s.infoMutex.RLock()
info := s.info
s.infoMutex.RUnlock()
if info.dispatcher == nil {
errors.LogError(info.ctx, "unexpected: dispatcher == nil")
return return
} }
defer conn.Close() defer conn.Close()
ctx, cancel := context.WithCancel(core.ToBackgroundDetachedContext(s.info.ctx)) ctx, cancel := context.WithCancel(core.ToBackgroundDetachedContext(info.ctx))
sid := session.NewID() sid := session.NewID()
ctx = c.ContextWithID(ctx, sid) ctx = c.ContextWithID(ctx, sid)
inbound := session.Inbound{} // since promiscuousModeHandler mixed-up context, we shallow copy inbound (tag) and content (configs) inbound := session.Inbound{} // since promiscuousModeHandler mixed-up context, we shallow copy inbound (tag) and content (configs)
if s.info.inboundTag != nil { if info.inboundTag != nil {
inbound = *s.info.inboundTag inbound = *info.inboundTag
} }
inbound.Name = "wireguard" inbound.Name = "wireguard"
inbound.CanSpliceCopy = 3 inbound.CanSpliceCopy = 3
@@ -141,8 +157,8 @@ func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) {
// Currently we have no way to link to the original source address // Currently we have no way to link to the original source address
inbound.Source = net.DestinationFromAddr(conn.RemoteAddr()) inbound.Source = net.DestinationFromAddr(conn.RemoteAddr())
ctx = session.ContextWithInbound(ctx, &inbound) ctx = session.ContextWithInbound(ctx, &inbound)
if s.info.contentTag != nil { if info.contentTag != nil {
ctx = session.ContextWithContent(ctx, s.info.contentTag) ctx = session.ContextWithContent(ctx, info.contentTag)
} }
ctx = session.SubContextFromMuxInbound(ctx) ctx = session.SubContextFromMuxInbound(ctx)
@@ -156,7 +172,7 @@ func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) {
Reason: "", Reason: "",
}) })
link, err := s.info.dispatcher.Dispatch(ctx, dest) link, err := info.dispatcher.Dispatch(ctx, dest)
if err != nil { if err != nil {
errors.LogErrorInner(ctx, err, "dispatch connection") errors.LogErrorInner(ctx, err, "dispatch connection")
} }