Files
Xray-core/proxy/wireguard/server.go
T
copilot-swe-agent[bot] 385867e82b Fix race condition in WireGuard server with concurrent peer connections
Add mutex protection to server.go to prevent race condition when multiple
peers connect simultaneously. The shared routingInfo field was being
overwritten by concurrent Process() calls, causing connections to fail.

- Add sync.RWMutex to protect access to routing info
- Only update routing info if not already set or dispatcher changed
- Use local copy of routing info in forwardConnection to avoid races
- Existing tests pass

Co-authored-by: RPRX <63339210+RPRX@users.noreply.github.com>
2026-01-09 10:28:10 +00:00

207 lines
5.5 KiB
Go

package wireguard
import (
"context"
goerrors "errors"
"io"
"sync"
"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/buf"
c "github.com/xtls/xray-core/common/ctx"
"github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/common/log"
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/session"
"github.com/xtls/xray-core/common/signal"
"github.com/xtls/xray-core/common/task"
"github.com/xtls/xray-core/core"
"github.com/xtls/xray-core/features/dns"
"github.com/xtls/xray-core/features/policy"
"github.com/xtls/xray-core/features/routing"
"github.com/xtls/xray-core/transport/internet/stat"
)
// nullDestination is a TCP destination built from net.AnyIP and port 0. It is
// used as the "From" field of the access log entry written for every
// forwarded connection.
var nullDestination = net.TCPDestination(net.AnyIP, 0)
// Server is the WireGuard inbound handler. Process feeds packets received from
// peers into bindServer, and forwardConnection dispatches the connections the
// tunnel hands back.
type Server struct {
// bindServer is the bind passed to the tunnel device in NewServer; Process
// parses peer endpoints with it and fills the slots of its readQueue.
bindServer *netBindServer
// infoMutex guards info. Process takes the write lock to update it and
// forwardConnection takes the read lock to copy it.
//
// info is the default routing info for forwarded connections. Since we cannot
// determine which peer initiated a forwarded connection from gvisor, the most
// recently stored routing info is used for all of them.
infoMutex sync.RWMutex
info routingInfo
// policyManager supplies the level-0 policy whose timeouts forwardConnection
// applies to each forwarded connection.
policyManager policy.Manager
}
// routingInfo is the routing state captured by Process and later copied by
// forwardConnection to build the context of a forwarded connection.
type routingInfo struct {
// ctx is the context of the Process call that stored this value.
ctx context.Context
// dispatcher is the dispatcher given to that Process call; forwardConnection
// calls Dispatch on it.
dispatcher routing.Dispatcher
// inboundTag is the result of session.InboundFromContext(ctx); may be nil.
inboundTag *session.Inbound
// contentTag is the result of session.ContentFromContext(ctx); may be nil.
contentTag *session.Content
}
// NewServer builds a WireGuard inbound Server from conf.
//
// It resolves the DNS client and policy manager features from the core
// instance stored in ctx, creates the tunnel with forwardConnection as its
// connection callback, and builds the WireGuard device on top of the server's
// bind. If building the device fails, the tunnel is closed before the error is
// returned.
func NewServer(ctx context.Context, conf *DeviceConfig) (*Server, error) {
	instance := core.MustFromContext(ctx)

	endpoints, hasIPv4, hasIPv6, err := parseEndpoints(conf)
	if err != nil {
		return nil, err
	}

	// The bind resolves names with the instance's DNS client, restricted to
	// the address families reported by parseEndpoints.
	bind := &netBindServer{
		netBind: netBind{
			dns: instance.GetFeature(dns.ClientType()).(dns.Client),
			dnsOption: dns.IPOption{
				IPv4Enable: hasIPv4,
				IPv6Enable: hasIPv6,
			},
		},
	}
	server := &Server{
		bindServer:    bind,
		policyManager: instance.GetFeature(policy.ManagerType()).(policy.Manager),
	}

	newTun := conf.createTun()
	tun, err := newTun(endpoints, int(conf.Mtu), server.forwardConnection)
	if err != nil {
		return nil, err
	}

	ipcRequest := createIPCRequest(conf)
	err = tun.BuildDevice(ipcRequest, server.bindServer)
	if err != nil {
		// Best-effort cleanup; the BuildDevice error is the one worth reporting.
		_ = tun.Close()
		return nil, err
	}

	return server, nil
}
// Network implements proxy.Inbound. The server accepts UDP only.
func (s *Server) Network() []net.Network {
	supported := []net.Network{net.Network_UDP}
	return supported
}
// Process implements proxy.Inbound.
//
// It records the routing info for this call (only when none is stored yet or
// the dispatcher differs from the stored one), binds conn to the endpoint
// parsed from its remote address, and then copies every packet read from conn
// into a slot taken from the bind's readQueue.
//
// It returns the read error when reading from conn fails, and nil when the
// readQueue is closed or a packet read reports io.EOF.
func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error {
	// Write lock: several Process calls may run at once and share s.info.
	s.infoMutex.Lock()
	if current := s.info.dispatcher; current == nil || current != dispatcher {
		s.info = routingInfo{
			ctx:        ctx,
			dispatcher: dispatcher,
			inboundTag: session.InboundFromContext(ctx),
			contentTag: session.ContentFromContext(ctx),
		}
	}
	s.infoMutex.Unlock()

	endpoint, err := s.bindServer.ParseEndpoint(conn.RemoteAddr().String())
	if err != nil {
		return err
	}
	peer := endpoint.(*netEndpoint)
	peer.conn = conn

	packets := buf.NewPacketReader(conn)
	for {
		batch, err := packets.ReadMultiBuffer()
		if err != nil {
			return err
		}

		for _, packet := range batch {
			// Take the next free slot; a closed queue ends processing.
			slot, open := <-s.bindServer.readQueue
			if !open {
				return nil
			}

			n, readErr := packet.Read(slot.buff)
			slot.bytes = n
			slot.endpoint = peer
			slot.err = readErr
			// Release whoever is waiting on this slot, even when the read failed.
			slot.waiter.Done()

			// errors.Is reports false for a nil error, so no separate nil check.
			if goerrors.Is(readErr, io.EOF) {
				peer.conn = nil
				return nil
			}
		}
	}
}
// forwardConnection relays one connection handed over by the tunnel to the
// dispatcher stored by Process. NewServer registers it as the tunnel's
// connection callback.
//
// dest is the destination to dispatch to and conn is the tunnel-side
// connection. conn is closed on every return path. If no dispatcher has been
// stored yet, or dispatching fails, the error is logged and the connection is
// dropped.
func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) {
	// Registered first so the early returns below cannot leak conn.
	defer conn.Close()

	// Work on a private copy so a concurrent Process cannot swap the routing
	// info halfway through this function.
	s.infoMutex.RLock()
	info := s.info
	s.infoMutex.RUnlock()

	if info.dispatcher == nil {
		// info is normally the zero value here, so avoid handing a nil
		// context to the logger.
		logCtx := info.ctx
		if logCtx == nil {
			logCtx = context.Background()
		}
		errors.LogError(logCtx, "unexpected: dispatcher == nil")
		return
	}

	ctx, cancel := context.WithCancel(core.ToBackgroundDetachedContext(info.ctx))
	// Runs before conn.Close (LIFO), on every path from here on.
	defer cancel()

	sid := session.NewID()
	ctx = c.ContextWithID(ctx, sid)

	inbound := session.Inbound{} // since promiscuousModeHandler mixed-up context, we shallow copy inbound (tag) and content (configs)
	if info.inboundTag != nil {
		inbound = *info.inboundTag
	}
	inbound.Name = "wireguard"
	inbound.CanSpliceCopy = 3

	// overwrite the source to use the tun address for each sub context.
	// Since gvisor.ForwarderRequest doesn't provide any info to associate the sub-context with the Parent context
	// Currently we have no way to link to the original source address
	inbound.Source = net.DestinationFromAddr(conn.RemoteAddr())
	ctx = session.ContextWithInbound(ctx, &inbound)
	if info.contentTag != nil {
		ctx = session.ContextWithContent(ctx, info.contentTag)
	}
	ctx = session.SubContextFromMuxInbound(ctx)

	plcy := s.policyManager.ForLevel(0)
	timer := signal.CancelAfterInactivity(ctx, cancel, plcy.Timeouts.ConnectionIdle)

	ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
		From:   nullDestination,
		To:     dest,
		Status: log.AccessAccepted,
		Reason: "",
	})

	link, err := info.dispatcher.Dispatch(ctx, dest)
	if err != nil {
		// link must not be used after a failed dispatch; continuing would
		// dereference it in the copy loops below.
		errors.LogErrorInner(ctx, err, "dispatch connection")
		return
	}

	// Uplink: tunnel connection -> dispatched link.
	requestDone := func() error {
		defer timer.SetTimeout(plcy.Timeouts.DownlinkOnly)
		if err := buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer)); err != nil {
			return errors.New("failed to transport all TCP request").Base(err)
		}
		return nil
	}

	// Downlink: dispatched link -> tunnel connection.
	responseDone := func() error {
		defer timer.SetTimeout(plcy.Timeouts.UplinkOnly)
		if err := buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer)); err != nil {
			return errors.New("failed to transport all TCP response").Base(err)
		}
		return nil
	}

	// Close the link writer once the uplink finished without error.
	requestDonePost := task.OnSuccess(requestDone, task.Close(link.Writer))
	if err := task.Run(ctx, requestDonePost, responseDone); err != nil {
		common.Interrupt(link.Reader)
		common.Interrupt(link.Writer)
		errors.LogDebugInner(ctx, err, "connection ends")
	}
}