mirror of
https://github.com/XTLS/Xray-core.git
synced 2026-07-04 10:48:49 +00:00
WireGuard proxy: Refactor (#6287)
And https://github.com/XTLS/Xray-core/pull/6303#issuecomment-4669158076 Fixes https://github.com/XTLS/Xray-core/issues/6257
This commit is contained in:
+200
-112
@@ -2,6 +2,10 @@ package wireguard
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/xtls/xray-core/common/buf"
|
||||
c "github.com/xtls/xray-core/common/ctx"
|
||||
@@ -10,162 +14,246 @@ import (
|
||||
"github.com/xtls/xray-core/common/net"
|
||||
"github.com/xtls/xray-core/common/session"
|
||||
"github.com/xtls/xray-core/core"
|
||||
"github.com/xtls/xray-core/features/dns"
|
||||
"github.com/xtls/xray-core/features/policy"
|
||||
"github.com/xtls/xray-core/features/routing"
|
||||
"github.com/xtls/xray-core/features/stats"
|
||||
"github.com/xtls/xray-core/transport"
|
||||
"github.com/xtls/xray-core/transport/internet"
|
||||
"github.com/xtls/xray-core/transport/internet/stat"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
)
|
||||
|
||||
var nullDestination = net.TCPDestination(net.AnyIP, 0)
|
||||
|
||||
type Server struct {
|
||||
bindServer *netBindServer
|
||||
|
||||
info routingInfo
|
||||
conf *DeviceConfig
|
||||
ctx context.Context
|
||||
policyManager policy.Manager
|
||||
}
|
||||
dispatcher routing.Dispatcher
|
||||
|
||||
type routingInfo struct {
|
||||
ctx context.Context
|
||||
dispatcher routing.Dispatcher
|
||||
inboundTag *session.Inbound
|
||||
contentTag *session.Content
|
||||
tag string
|
||||
src net.Destination
|
||||
sniffingRequest session.SniffingRequest
|
||||
streamSettings *internet.MemoryStreamConfig
|
||||
uplinkCounter stats.Counter
|
||||
downlinkCounter stats.Counter
|
||||
|
||||
tun tun.Device
|
||||
stack *stack.Stack
|
||||
dev *device.Device
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewServer(ctx context.Context, conf *DeviceConfig) (*Server, error) {
|
||||
v := core.MustFromContext(ctx)
|
||||
p := v.GetFeature(policy.ManagerType()).(policy.Manager)
|
||||
d := v.GetFeature(routing.DispatcherType()).(routing.Dispatcher)
|
||||
|
||||
endpoints, hasIPv4, hasIPv6, err := parseEndpoints(conf)
|
||||
inbound := session.InboundFromContext(ctx)
|
||||
content := session.ContentFromContext(ctx)
|
||||
streamSettings := session.StreamSettingsFromContext(ctx).(*internet.MemoryStreamConfig)
|
||||
tag := inbound.Tag
|
||||
var uplinkCounter stats.Counter
|
||||
var downlinkCounter stats.Counter
|
||||
if len(tag) > 0 && p.ForSystem().Stats.InboundUplink {
|
||||
statsManager := v.GetFeature(stats.ManagerType()).(stats.Manager)
|
||||
name := "inbound>>>" + tag + ">>>traffic>>>uplink"
|
||||
c, _ := stats.GetOrRegisterCounter(statsManager, name)
|
||||
if c != nil {
|
||||
uplinkCounter = c
|
||||
}
|
||||
}
|
||||
if len(tag) > 0 && p.ForSystem().Stats.InboundDownlink {
|
||||
statsManager := v.GetFeature(stats.ManagerType()).(stats.Manager)
|
||||
name := "inbound>>>" + tag + ">>>traffic>>>downlink"
|
||||
c, _ := stats.GetOrRegisterCounter(statsManager, name)
|
||||
if c != nil {
|
||||
downlinkCounter = c
|
||||
}
|
||||
}
|
||||
|
||||
if len(conf.Peers) == 0 {
|
||||
return nil, errors.New("empty peers")
|
||||
}
|
||||
for _, peer := range conf.Peers {
|
||||
if peer.PublicKey == "" {
|
||||
return nil, errors.New("peer without publickey")
|
||||
}
|
||||
}
|
||||
|
||||
localAddresses := make([]netip.Addr, 0, len(conf.Endpoint))
|
||||
for _, localaddress := range conf.Endpoint {
|
||||
addr, err := netip.ParseAddr(localaddress)
|
||||
if err == nil {
|
||||
localAddresses = append(localAddresses, addr)
|
||||
continue
|
||||
}
|
||||
prefix, err := netip.ParsePrefix(localaddress)
|
||||
if err == nil {
|
||||
localAddresses = append(localAddresses, prefix.Addr())
|
||||
continue
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tun, _, stack, err := CreateNetTUN(localAddresses, nil, int(conf.Mtu), false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
server := &Server{
|
||||
bindServer: &netBindServer{
|
||||
netBind: netBind{
|
||||
dns: v.GetFeature(dns.ClientType()).(dns.Client),
|
||||
dnsOption: dns.IPOption{
|
||||
IPv4Enable: hasIPv4,
|
||||
IPv6Enable: hasIPv6,
|
||||
},
|
||||
workers: int(conf.NumWorkers),
|
||||
readQueue: make(chan *netReadInfo),
|
||||
},
|
||||
},
|
||||
policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
|
||||
}
|
||||
return &Server{
|
||||
conf: conf,
|
||||
ctx: core.ToBackgroundDetachedContext(ctx),
|
||||
policyManager: p,
|
||||
dispatcher: d,
|
||||
|
||||
tun, err := conf.createTun()(endpoints, int(conf.Mtu), server.forwardConnection)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tag: inbound.Tag,
|
||||
src: inbound.Source,
|
||||
sniffingRequest: content.SniffingRequest,
|
||||
streamSettings: streamSettings,
|
||||
uplinkCounter: uplinkCounter,
|
||||
downlinkCounter: downlinkCounter,
|
||||
|
||||
if err = tun.BuildDevice(createIPCRequest(conf), server.bindServer); err != nil {
|
||||
_ = tun.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return server, nil
|
||||
tun: tun,
|
||||
stack: stack,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Network implements proxy.Inbound.
|
||||
// Network implements proxy.Inbound.Network.
|
||||
func (*Server) Network() []net.Network {
|
||||
return []net.Network{net.Network_UDP}
|
||||
return []net.Network{}
|
||||
}
|
||||
|
||||
// Process implements proxy.Inbound.
|
||||
// Process implements proxy.Inbound.Process.
|
||||
func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error {
|
||||
s.info = routingInfo{
|
||||
ctx: ctx,
|
||||
dispatcher: dispatcher,
|
||||
inboundTag: session.InboundFromContext(ctx),
|
||||
contentTag: session.ContentFromContext(ctx),
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
ep, err := s.bindServer.ParseEndpoint(conn.RemoteAddr().String())
|
||||
// Close implements common.Closable.Close.
|
||||
func (s *Server) Close() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.dev != nil {
|
||||
s.dev.Close()
|
||||
s.dev = nil
|
||||
s.tun = nil
|
||||
} else if s.tun != nil {
|
||||
s.tun.Close()
|
||||
s.tun = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Start implements common.Runnable.Start.
|
||||
func (s *Server) Start() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.dev != nil {
|
||||
return nil
|
||||
}
|
||||
if s.src.Address.Family().IsDomain() {
|
||||
return errors.New("address is domain")
|
||||
}
|
||||
listenFunc := func() (net.PacketConn, error) {
|
||||
pktConn, err := internet.ListenSystemPacket(context.Background(), &net.UDPAddr{IP: s.src.Address.IP(), Port: int(s.src.Port)}, s.streamSettings.SocketSettings)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.streamSettings.UdpmaskManager != nil {
|
||||
newConn, err := s.streamSettings.UdpmaskManager.WrapPacketConnServer(pktConn)
|
||||
if err != nil {
|
||||
pktConn.Close()
|
||||
return nil, errors.New("mask err").Base(err)
|
||||
}
|
||||
pktConn = newConn
|
||||
}
|
||||
if s.uplinkCounter != nil || s.downlinkCounter != nil {
|
||||
pktConn = &PacketCounterConnection{
|
||||
PacketConn: pktConn,
|
||||
ReadCounter: s.uplinkCounter,
|
||||
WriteCounter: s.downlinkCounter,
|
||||
}
|
||||
}
|
||||
return pktConn, nil
|
||||
}
|
||||
bind := &bind{
|
||||
listenFunc: listenFunc,
|
||||
}
|
||||
logger := &device.Logger{
|
||||
Verbosef: func(format string, args ...any) {
|
||||
log.Record(&log.GeneralMessage{
|
||||
Severity: log.Severity_Debug,
|
||||
Content: fmt.Sprintf(format, args...),
|
||||
})
|
||||
},
|
||||
Errorf: func(format string, args ...any) {
|
||||
log.Record(&log.GeneralMessage{
|
||||
Severity: log.Severity_Error,
|
||||
Content: fmt.Sprintf(format, args...),
|
||||
})
|
||||
},
|
||||
}
|
||||
dev := device.NewDevice(s.tun, bind, logger)
|
||||
var cfg strings.Builder
|
||||
cfg.WriteString("private_key=" + s.conf.SecretKey + "\n")
|
||||
for _, peer := range s.conf.Peers {
|
||||
cfg.WriteString("public_key=" + peer.PublicKey + "\n")
|
||||
if peer.PreSharedKey != "" {
|
||||
cfg.WriteString("preshared_key=" + peer.PreSharedKey + "\n")
|
||||
}
|
||||
for _, ip := range peer.AllowedIps {
|
||||
cfg.WriteString("allowed_ip=" + ip + "\n")
|
||||
}
|
||||
if peer.KeepAlive != "" {
|
||||
cfg.WriteString("persistent_keepalive_interval=" + peer.KeepAlive + "\n")
|
||||
}
|
||||
}
|
||||
err := dev.IpcSet(cfg.String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
nep := ep.(*netEndpoint)
|
||||
nep.conn = conn
|
||||
|
||||
reader := buf.NewPacketReader(conn)
|
||||
for {
|
||||
mb, err := reader.ReadMultiBuffer()
|
||||
if err != nil {
|
||||
nep.conn = nil
|
||||
buf.ReleaseMulti(mb)
|
||||
return err
|
||||
}
|
||||
|
||||
for i, b := range mb {
|
||||
|
||||
rawBytes := b.Bytes()
|
||||
if b.Len() > 3 {
|
||||
rawBytes[1] = 0
|
||||
rawBytes[2] = 0
|
||||
rawBytes[3] = 0
|
||||
}
|
||||
|
||||
select {
|
||||
case s.bindServer.readQueue <- &netReadInfo{
|
||||
buff: b,
|
||||
endpoint: nep,
|
||||
}:
|
||||
case <-s.bindServer.closedCh:
|
||||
nep.conn = nil
|
||||
buf.ReleaseMulti(mb[i:])
|
||||
return errors.New("bind closed")
|
||||
}
|
||||
}
|
||||
err = dev.Up()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.dev = dev
|
||||
createForwarder(s.stack, s.HandleConnection)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) {
|
||||
if s.info.dispatcher == nil {
|
||||
errors.LogError(s.info.ctx, "unexpected: dispatcher == nil")
|
||||
return
|
||||
func (s *Server) HandleConnection(conn net.Conn, dest net.Destination) {
|
||||
defer conn.Close()
|
||||
ctx, cancel := context.WithCancel(s.ctx)
|
||||
defer cancel()
|
||||
ctx = c.ContextWithID(ctx, session.NewID())
|
||||
|
||||
source := net.DestinationFromAddr(conn.RemoteAddr())
|
||||
inbound := session.Inbound{
|
||||
Name: "wireguard",
|
||||
Tag: s.tag,
|
||||
CanSpliceCopy: 3,
|
||||
Source: source,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(core.ToBackgroundDetachedContext(s.info.ctx))
|
||||
sid := session.NewID()
|
||||
ctx = c.ContextWithID(ctx, sid)
|
||||
inbound := session.Inbound{} // since promiscuousModeHandler mixed-up context, we shallow copy inbound (tag) and content (configs)
|
||||
if s.info.inboundTag != nil {
|
||||
inbound = *s.info.inboundTag
|
||||
}
|
||||
inbound.Name = "wireguard"
|
||||
inbound.CanSpliceCopy = 3
|
||||
|
||||
// overwrite the source to use the tun address for each sub context.
|
||||
// Since gvisor.ForwarderRequest doesn't provide any info to associate the sub-context with the Parent context
|
||||
// Currently we have no way to link to the original source address
|
||||
inbound.Source = net.DestinationFromAddr(conn.RemoteAddr())
|
||||
ctx = session.ContextWithInbound(ctx, &inbound)
|
||||
content := new(session.Content)
|
||||
if s.info.contentTag != nil {
|
||||
content.SniffingRequest = s.info.contentTag.SniffingRequest
|
||||
}
|
||||
ctx = session.ContextWithContent(ctx, content)
|
||||
ctx = session.ContextWithContent(ctx, &session.Content{
|
||||
SniffingRequest: s.sniffingRequest,
|
||||
})
|
||||
ctx = session.SubContextFromMuxInbound(ctx)
|
||||
|
||||
ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
|
||||
From: nullDestination,
|
||||
From: inbound.Source,
|
||||
To: dest,
|
||||
Status: log.AccessAccepted,
|
||||
Reason: "",
|
||||
})
|
||||
errors.LogInfo(ctx, "processing from ", source, " to ", dest)
|
||||
|
||||
err := s.info.dispatcher.DispatchLink(ctx, dest, &transport.Link{
|
||||
Reader: buf.NewReader(conn),
|
||||
link := &transport.Link{
|
||||
Reader: &buf.TimeoutWrapperReader{Reader: buf.NewReader(conn)},
|
||||
Writer: buf.NewWriter(conn),
|
||||
})
|
||||
if err != nil {
|
||||
errors.LogInfoInner(ctx, err, "connection ends")
|
||||
}
|
||||
|
||||
cancel()
|
||||
conn.Close()
|
||||
if err := s.dispatcher.DispatchLink(ctx, dest, link); err != nil {
|
||||
errors.LogError(ctx, errors.New("connection closed").Base(err))
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user