diff --git a/app/proxyman/inbound/always.go b/app/proxyman/inbound/always.go index f26a30c2e..462f894f5 100644 --- a/app/proxyman/inbound/always.go +++ b/app/proxyman/inbound/always.go @@ -57,16 +57,23 @@ func NewAlwaysOnInboundHandler(ctx context.Context, tag string, receiverConfig * if err != nil { return nil, err } - - // Set tag and sniffing config in context before creating proxy - // This allows proxies like TUN to access these settings - ctx = session.ContextWithInbound(ctx, &session.Inbound{Tag: tag}) - if receiverConfig.SniffingSettings != nil { - ctx = session.ContextWithContent(ctx, &session.Content{ - SniffingRequest: sniffingRequest, - }) + src := net.TCPDestination(net.AnyIP, 0) + if receiverConfig.Listen != nil { + src.Address = receiverConfig.Listen.AsAddress() } - rawProxy, err := common.CreateObject(ctx, proxyConfig) + if receiverConfig.PortList != nil && len(receiverConfig.PortList.Range) > 0 { + src.Port = net.Port(receiverConfig.PortList.Range[0].From) + } + mss, err := internet.ToMemoryStreamConfig(receiverConfig.StreamSettings) + if err != nil { + return nil, errors.New("failed to parse stream config").Base(err).AtWarning() + } + + newCtx := session.ContextWithInbound(ctx, &session.Inbound{Tag: tag, Source: src}) + newCtx = session.ContextWithContent(newCtx, &session.Content{SniffingRequest: sniffingRequest}) + newCtx = session.ContextWithStreamSettings(newCtx, mss) + + rawProxy, err := common.CreateObject(newCtx, proxyConfig) if err != nil { return nil, err } @@ -92,11 +99,6 @@ func NewAlwaysOnInboundHandler(ctx context.Context, tag string, receiverConfig * address = net.AnyIP } - mss, err := internet.ToMemoryStreamConfig(receiverConfig.StreamSettings) - if err != nil { - return nil, errors.New("failed to parse stream config").Base(err).AtWarning() - } - if receiverConfig.ReceiveOriginalDestination { if mss.SocketSettings == nil { mss.SocketSettings = &internet.SocketConfig{} diff --git a/app/proxyman/outbound/handler.go b/app/proxyman/outbound/handler.go index 578b2ca16..2fe0fec55 100644 --- a/app/proxyman/outbound/handler.go +++ b/app/proxyman/outbound/handler.go @@ -108,7 +108,9 @@ func NewHandler(ctx context.Context, config *core.OutboundHandlerConfig) (outbou ctx = session.ContextWithFullHandler(ctx, h) - rawProxyHandler, err := common.CreateObject(ctx, proxyConfig) + newCtx := session.ContextWithStreamSettings(ctx, h.streamSettings) + + rawProxyHandler, err := common.CreateObject(newCtx, proxyConfig) if err != nil { return nil, err } diff --git a/common/session/context.go b/common/session/context.go index c28f20816..abe4219ca 100644 --- a/common/session/context.go +++ b/common/session/context.go @@ -26,6 +26,8 @@ const ( fullHandlerKey ctx.SessionKey = 10 // outbound gets full handler mitmAlpn11Key ctx.SessionKey = 11 // used by TLS dialer mitmServerNameKey ctx.SessionKey = 12 // used by TLS dialer + + streamSettingsKey ctx.SessionKey = 13 ) func ContextWithInbound(ctx context.Context, inbound *Inbound) context.Context { @@ -192,3 +194,11 @@ func MitmServerNameFromContext(ctx context.Context) string { } return "" } + +func ContextWithStreamSettings(ctx context.Context, streamSettings any) context.Context { + return context.WithValue(ctx, streamSettingsKey, streamSettings) +} + +func StreamSettingsFromContext(ctx context.Context) any { + return ctx.Value(streamSettingsKey) +} diff --git a/infra/conf/wireguard.go b/infra/conf/wireguard.go index 1ba61f96b..d985184f4 100644 --- a/infra/conf/wireguard.go +++ b/infra/conf/wireguard.go @@ -3,6 +3,7 @@ package conf import ( "encoding/base64" "encoding/hex" + "strconv" "strings" "github.com/xtls/xray-core/common/errors" @@ -37,8 +38,9 @@ func (c *WireGuardPeerConfig) Build() (proto.Message, error) { } config.Endpoint = c.Endpoint - // default 0 - config.KeepAlive = c.KeepAlive + if c.KeepAlive != 0 { + config.KeepAlive = strconv.FormatUint(uint64(c.KeepAlive), 10) + } if c.AllowedIPs == nil { config.AllowedIps = []string{"0.0.0.0/0", "::0/0"} } else { @@ -56,7 +58,6 @@ type WireGuardConfig struct { Address []string `json:"address"` Peers []*WireGuardPeerConfig `json:"peers"` MTU int32 `json:"mtu"` - NumWorkers int32 `json:"workers"` Reserved []byte `json:"reserved"` DomainStrategy string `json:"domainStrategy"` } @@ -93,9 +94,6 @@ func (c *WireGuardConfig) Build() (proto.Message, error) { } else { config.Mtu = c.MTU } - // these a fallback code exists in wireguard-go code, - // we don't need to process fallback manually - config.NumWorkers = c.NumWorkers if len(c.Reserved) != 0 && len(c.Reserved) != 3 { return nil, errors.New(`"reserved" should be empty or 3 bytes`) diff --git a/infra/conf/wireguard_test.go b/infra/conf/wireguard_test.go index c4c24c44a..2e1903db8 100644 --- a/infra/conf/wireguard_test.go +++ b/infra/conf/wireguard_test.go @@ -38,12 +38,10 @@ func TestWireGuardConfig(t *testing.T) { // also can read from hex form directly PublicKey: "6e65ce0be17517110c17d77288ad87e7fd5252dcc7d09b95a39d61db03df832a", Endpoint: "127.0.0.1:1234", - KeepAlive: 0, AllowedIps: []string{"0.0.0.0/0", "::0/0"}, }, }, Mtu: 1300, - NumWorkers: 2, DomainStrategy: wireguard.DeviceConfig_FORCE_IP64, NoKernelTun: false, }, diff --git a/proxy/hysteria/client.go b/proxy/hysteria/client.go index 7d5730b6b..5b602c342 100644 --- a/proxy/hysteria/client.go +++ b/proxy/hysteria/client.go @@ -29,6 +29,13 @@ type Client struct { } func NewClient(ctx context.Context, config *ClientConfig) (*Client, error) { + v := core.MustFromContext(ctx) + p := v.GetFeature(policy.ManagerType()).(policy.Manager) + + streamSettings := session.StreamSettingsFromContext(ctx).(*internet.MemoryStreamConfig) + if _, ok := streamSettings.ProtocolSettings.(*hysteria.Config); !ok { + return nil, errors.New("not hysteria transport") + } if config.Server == nil { return nil, errors.New(`no target server found`) } @@ -37,12 +44,10 @@ func NewClient(ctx context.Context, config *ClientConfig) (*Client, error) { return nil, errors.New("failed to get server spec").Base(err) } - v := core.MustFromContext(ctx) - client := &Client{ + return &Client{ server: server, - policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager), - } - return client, nil + policyManager: p, + }, nil } func (c *Client) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error { diff --git a/proxy/hysteria/server.go b/proxy/hysteria/server.go index 815faca11..d7456dc30 100644 --- a/proxy/hysteria/server.go +++ b/proxy/hysteria/server.go @@ -16,6 +16,7 @@ import ( "github.com/xtls/xray-core/features/routing" "github.com/xtls/xray-core/proxy/hysteria/account" "github.com/xtls/xray-core/transport" + "github.com/xtls/xray-core/transport/internet" "github.com/xtls/xray-core/transport/internet/hysteria" "github.com/xtls/xray-core/transport/internet/stat" ) @@ -27,6 +28,14 @@ type Server struct { } func NewServer(ctx context.Context, config *ServerConfig) (*Server, error) { + v := core.MustFromContext(ctx) + p := v.GetFeature(policy.ManagerType()).(policy.Manager) + + streamSettings := session.StreamSettingsFromContext(ctx).(*internet.MemoryStreamConfig) + if _, ok := streamSettings.ProtocolSettings.(*hysteria.Config); !ok { + return nil, errors.New("not hysteria transport") + } + validator := account.NewValidator() for _, user := range config.Users { u, err := user.ToMemoryUser() @@ -39,14 +48,11 @@ func NewServer(ctx context.Context, config *ServerConfig) (*Server, error) { } } - v := core.MustFromContext(ctx) - s := &Server{ + return &Server{ config: config, validator: validator, - policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager), - } - - return s, nil + policyManager: p, + }, nil } func (s *Server) HysteriaInboundValidator() *account.Validator { diff --git a/proxy/wireguard/bind.go b/proxy/wireguard/bind.go index c5c0f8c8e..0e71e1539 100644 --- a/proxy/wireguard/bind.go +++ b/proxy/wireguard/bind.go @@ -2,265 +2,150 @@ package wireguard import ( "context" - gonet "net" + goerrors "errors" + "io" + "net" "net/netip" - "runtime" "strconv" + "sync" + "syscall" - "golang.zx2c4.com/wireguard/conn" - "golang.zx2c4.com/wireguard/device" - - "github.com/xtls/xray-core/common/buf" + "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/errors" - "github.com/xtls/xray-core/common/net" - "github.com/xtls/xray-core/features/dns" - "github.com/xtls/xray-core/transport/internet" + "golang.zx2c4.com/wireguard/conn" ) -type netReadInfo struct { - buff *buf.Buffer - endpoint conn.Endpoint +type bind struct { + resolveFunc func(host string) (net.IP, error) + listenFunc func() (net.PacketConn, error) + downFunc func() error + reserved []byte + + net.PacketConn + closeCh chan struct{} + mu sync.Mutex } -// reduce duplicated code -type netBind struct { - dns dns.Client - dnsOption dns.IPOption +func (b *bind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) { + b.mu.Lock() + defer b.mu.Unlock() - workers int - readQueue chan *netReadInfo - closedCh chan struct{} + if b.PacketConn != nil { + return nil, 0, conn.ErrBindAlreadyOpen + } + + c, err := b.listenFunc() + if err != nil { + return nil, 0, err + } + b.PacketConn = c + ch := make(chan struct{}) + b.closeCh = ch + + return []conn.ReceiveFunc{ + func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) { + for { + n, addr, err := c.ReadFrom(bufs[0]) + if err != nil { + if goerrors.Is(err, io.EOF) || goerrors.Is(err, io.ErrClosedPipe) || goerrors.Is(err, net.ErrClosed) { + select { + case <-ch: + default: + errors.LogErrorInner(context.Background(), err, "unexpected closed") + if b.downFunc != nil { + go func() { + common.Must(b.downFunc()) + }() + } + } + return 0, net.ErrClosed + } + errors.LogErrorInner(context.Background(), err, "bind recv err") + continue + } + if n > 3 { + bufs[0][1] = 0 + bufs[0][2] = 0 + bufs[0][3] = 0 + } + sizes[0] = n + eps[0] = &conn.StdNetEndpoint{AddrPort: addr.(*net.UDPAddr).AddrPort()} + return 1, nil + } + }, + }, uint16(c.LocalAddr().(*net.UDPAddr).Port), nil } -// SetMark implements conn.Bind -func (bind *netBind) SetMark(mark uint32) error { +func (b *bind) Close() error { + b.mu.Lock() + defer b.mu.Unlock() + if b.PacketConn != nil { + close(b.closeCh) + _ = b.PacketConn.Close() + b.PacketConn = nil + } return nil } -// ParseEndpoint implements conn.Bind -func (n *netBind) ParseEndpoint(s string) (conn.Endpoint, error) { - ipStr, port, err := net.SplitHostPort(s) - if err != nil { - return nil, err - } - portNum, err := strconv.Atoi(port) - if err != nil { - return nil, err +func (b *bind) SetMark(mark uint32) error { + return nil +} + +func (b *bind) Send(bufs [][]byte, ep conn.Endpoint) (err error) { + b.mu.Lock() + c := b.PacketConn + b.mu.Unlock() + + if c == nil { + return syscall.EAFNOSUPPORT } - addr := net.ParseAddress(ipStr) - if addr.Family() == net.AddressFamilyDomain { - ips, _, err := n.dns.LookupIP(addr.Domain(), n.dnsOption) + for i := range bufs { + if len(bufs[i]) > 3 && len(b.reserved) == 3 { + bufs[i][1] = b.reserved[0] + bufs[i][2] = b.reserved[1] + bufs[i][3] = b.reserved[2] + } + _, err = c.WriteTo(bufs[i], net.UDPAddrFromAddrPort(ep.(*conn.StdNetEndpoint).AddrPort)) if err != nil { - return nil, err - } else if len(ips) == 0 { - return nil, dns.ErrEmptyResponse - } - addr = net.IPAddress(ips[0]) - } - - dst := net.Destination{ - Address: addr, - Port: net.Port(portNum), - Network: net.Network_UDP, - } - - return &netEndpoint{ - dst: dst, - }, nil -} - -// BatchSize implements conn.Bind -func (bind *netBind) BatchSize() int { - return 1 -} - -// Open implements conn.Bind -func (bind *netBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) { - bind.closedCh = make(chan struct{}) - errors.LogDebug(context.Background(), "bind opened") - - fun := func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) { - select { - case r := <-bind.readQueue: - sizes[0], eps[0] = copy(bufs[0], r.buff.Bytes()), r.endpoint - r.buff.Release() - return 1, nil - case <-bind.closedCh: - errors.LogDebug(context.Background(), "recv func closed") - return 0, gonet.ErrClosed + errors.LogErrorInner(context.Background(), err, "bind send err") + break } } - workers := bind.workers - if workers <= 0 { - workers = runtime.NumCPU() - } - if workers <= 0 { - workers = 1 - } - arr := make([]conn.ReceiveFunc, workers) - for i := 0; i < workers; i++ { - arr[i] = fun - } - - return arr, uint16(uport), nil -} - -// Close implements conn.Bind -func (bind *netBind) Close() error { - errors.LogDebug(context.Background(), "bind closed") - if bind.closedCh != nil { - close(bind.closedCh) - } - return nil -} - -type netBindClient struct { - netBind - - ctx context.Context - dialer internet.Dialer - reserved []byte -} - -func (bind *netBindClient) connectTo(endpoint *netEndpoint) error { - c, err := bind.dialer.Dial(bind.ctx, endpoint.dst) - if err != nil { - return err - } - endpoint.conn = c - - go func() { - for { - buff := buf.NewWithSize(device.MaxMessageSize) - n, err := buff.ReadFrom(c) - if err != nil { - buff.Release() - endpoint.conn = nil - c.Close() - return - } - - rawBytes := buff.Bytes() - if n > 3 { - rawBytes[1] = 0 - rawBytes[2] = 0 - rawBytes[3] = 0 - } - - select { - case bind.readQueue <- &netReadInfo{ - buff: buff, - endpoint: endpoint, - }: - case <-bind.closedCh: - buff.Release() - endpoint.conn = nil - c.Close() - return - } - } - }() - - return nil -} - -func (bind *netBindClient) Send(buff [][]byte, endpoint conn.Endpoint) error { - var err error - - nend, ok := endpoint.(*netEndpoint) - if !ok { - return conn.ErrWrongEndpointType - } - - if nend.conn == nil { - err = bind.connectTo(nend) - if err != nil { - return err - } - } - - for _, buff := range buff { - if len(buff) > 3 && len(bind.reserved) == 3 { - copy(buff[1:], bind.reserved) - } - if _, err = nend.conn.Write(buff); err != nil { - return err - } - } - return nil -} - -type netBindServer struct { - netBind -} - -func (bind *netBindServer) Send(buff [][]byte, endpoint conn.Endpoint) error { - var err error - - nend, ok := endpoint.(*netEndpoint) - if !ok { - return conn.ErrWrongEndpointType - } - - if nend.conn == nil { - errors.LogDebug(context.Background(), nend.dst.NetAddr(), " send on closed peer") - return errors.New("peer closed") - } - - for _, buff := range buff { - if _, err = nend.conn.Write(buff); err != nil { - return err - } - } - return err } -type netEndpoint struct { - dst net.Destination - conn net.Conn -} - -func (netEndpoint) ClearSrc() {} - -func (e netEndpoint) DstIP() netip.Addr { - return netip.Addr{} -} - -func (e netEndpoint) SrcIP() netip.Addr { - return netip.Addr{} -} - -func (e netEndpoint) DstToBytes() []byte { - var dat []byte - if e.dst.Address.Family().IsIPv4() { - dat = e.dst.Address.IP().To4()[:] - } else { - dat = e.dst.Address.IP().To16()[:] - } - dat = append(dat, byte(e.dst.Port), byte(e.dst.Port>>8)) - return dat -} - -func (e netEndpoint) DstToString() string { - return e.dst.NetAddr() -} - -func (e netEndpoint) SrcToString() string { - return "" -} - -func toNetIpAddr(addr net.Address) netip.Addr { - if addr.Family().IsIPv4() { - ip := addr.IP() - return netip.AddrFrom4([4]byte{ip[0], ip[1], ip[2], ip[3]}) - } else { - ip := addr.IP() - arr := [16]byte{} - for i := 0; i < 16; i++ { - arr[i] = ip[i] +func (b *bind) ParseEndpoint(s string) (conn.Endpoint, error) { + if b.resolveFunc == nil { + e, err := netip.ParseAddrPort(s) + if err != nil { + return nil, err } - return netip.AddrFrom16(arr) + return &conn.StdNetEndpoint{ + AddrPort: e, + }, nil } + host, sport, err := net.SplitHostPort(s) + if err != nil { + return nil, err + } + port, err := strconv.Atoi(sport) + if err != nil { + return nil, err + } + if port < 0 || port > 65535 { + return nil, errors.New("invalid port " + sport) + } + ip, err := b.resolveFunc(host) + if err != nil { + return nil, err + } + addr, _ := netip.AddrFromSlice(ip) + return &conn.StdNetEndpoint{ + AddrPort: netip.AddrPortFrom(addr, uint16(port)), + }, nil +} + +func (b *bind) BatchSize() int { + return 1 } diff --git a/proxy/wireguard/client.go b/proxy/wireguard/client.go index 8a02d89f1..4491886ab 100644 --- a/proxy/wireguard/client.go +++ b/proxy/wireguard/client.go @@ -1,148 +1,135 @@ -/* - -Some of codes are copied from https://github.com/octeep/wireproxy, license below. - -Copyright (c) 2022 Wind T.F. Wong - -Permission to use, copy, modify, and distribute this software for any -purpose with or without fee is hereby granted, provided that the above -copyright notice and this permission notice appear in all copies. - -THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES -WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF -MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR -ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES -WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN -ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF -OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. - -*/ - package wireguard import ( "context" "fmt" + gonet "net" "net/netip" + reflect "reflect" "strings" "sync" + "golang.zx2c4.com/wireguard/tun" + "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/dice" "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/log" "github.com/xtls/xray-core/common/net" - "github.com/xtls/xray-core/common/protocol" + "github.com/xtls/xray-core/common/net/cnc" "github.com/xtls/xray-core/common/session" "github.com/xtls/xray-core/common/signal" "github.com/xtls/xray-core/common/task" "github.com/xtls/xray-core/core" "github.com/xtls/xray-core/features/dns" "github.com/xtls/xray-core/features/policy" + "github.com/xtls/xray-core/features/stats" "github.com/xtls/xray-core/transport" "github.com/xtls/xray-core/transport/internet" + "golang.zx2c4.com/wireguard/device" ) -// Handler is an outbound connection that silently swallow the entire payload. type Handler struct { conf *DeviceConfig - net Tunnel - bind *netBindClient policyManager policy.Manager dns dns.Client - // cached configuration - endpoints []netip.Addr - hasIPv4, hasIPv6 bool - wgLock sync.Mutex + + streamSettings *internet.MemoryStreamConfig + uplinkCounter stats.Counter + downlinkCounter stats.Counter + + tun tun.Device + tnet *Net + dev *device.Device + mu sync.Mutex } -// New creates a new wireguard handler. -func New(ctx context.Context, conf *DeviceConfig) (*Handler, error) { +func NewClient(ctx context.Context, conf *DeviceConfig) (*Handler, error) { v := core.MustFromContext(ctx) + p := v.GetFeature(policy.ManagerType()).(policy.Manager) + d := v.GetFeature(dns.ClientType()).(dns.Client) - endpoints, hasIPv4, hasIPv6, err := parseEndpoints(conf) + streamSettings := session.StreamSettingsFromContext(ctx).(*internet.MemoryStreamConfig) + tag := session.FullHandlerFromContext(ctx).Tag() + var uplinkCounter stats.Counter + var downlinkCounter stats.Counter + if len(tag) > 0 && p.ForSystem().Stats.OutboundUplink { + statsManager := v.GetFeature(stats.ManagerType()).(stats.Manager) + name := "outbound>>>" + tag + ">>>traffic>>>uplink" + c, _ := stats.GetOrRegisterCounter(statsManager, name) + if c != nil { + uplinkCounter = c + } + } + if len(tag) > 0 && p.ForSystem().Stats.OutboundDownlink { + statsManager := v.GetFeature(stats.ManagerType()).(stats.Manager) + name := "outbound>>>" + tag + ">>>traffic>>>downlink" + c, _ := stats.GetOrRegisterCounter(statsManager, name) + if c != nil { + downlinkCounter = c + } + } + + if len(conf.Peers) == 0 { + return nil, errors.New("empty peers") + } + for _, peer := range conf.Peers { + if peer.PublicKey == "" { + return nil, errors.New("peer without publickey") + } + if peer.Endpoint == "" { + return nil, errors.New("peer without endpoint") + } + } + + localAddresses := make([]netip.Addr, 0, len(conf.Endpoint)) + for _, localaddress := range conf.Endpoint { + addr, err := netip.ParseAddr(localaddress) + if err == nil { + localAddresses = append(localAddresses, addr) + continue + } + prefix, err := netip.ParsePrefix(localaddress) + if err == nil { + localAddresses = append(localAddresses, prefix.Addr()) + continue + } + return nil, err + } + + kernelTunSupported, err := KernelTunSupported() + if err != nil { + errors.LogWarningInner(context.Background(), err, "Failed to check kernel TUN support") + } + var tun tun.Device + var tnet *Net + if !conf.NoKernelTun && kernelTunSupported { + errors.LogWarning(context.Background(), "Using kernel TUN") + tun, tnet, err = createKernelTun(localAddresses, []netip.Addr{netip.MustParseAddr("1.1.1.1"), netip.MustParseAddr("1.0.0.1"), netip.MustParseAddr("2606:4700:4700::1111"), netip.MustParseAddr("2606:4700:4700::1001")}, int(conf.Mtu)) + } else { + errors.LogWarning(context.Background(), "Using gVisor TUN") + tun, tnet, _, err = CreateNetTUN(localAddresses, []netip.Addr{netip.MustParseAddr("1.1.1.1"), netip.MustParseAddr("1.0.0.1"), netip.MustParseAddr("2606:4700:4700::1111"), netip.MustParseAddr("2606:4700:4700::1001")}, int(conf.Mtu), true) + } if err != nil { return nil, err } - d := v.GetFeature(dns.ClientType()).(dns.Client) return &Handler{ conf: conf, - policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager), + policyManager: p, dns: d, - endpoints: endpoints, - hasIPv4: hasIPv4, - hasIPv6: hasIPv6, + + streamSettings: streamSettings, + uplinkCounter: uplinkCounter, + downlinkCounter: downlinkCounter, + + tun: tun, + tnet: tnet, }, nil } -func (h *Handler) Close() (err error) { - go func() { - h.wgLock.Lock() - defer h.wgLock.Unlock() - - if h.net != nil { - _ = h.net.Close() - h.net = nil - } - }() - - return nil -} - -func (h *Handler) processWireGuard(ctx context.Context, dialer internet.Dialer) (err error) { - h.wgLock.Lock() - defer h.wgLock.Unlock() - - if h.bind != nil && h.bind.dialer == dialer && h.net != nil { - return nil - } - - log.Record(&log.GeneralMessage{ - Severity: log.Severity_Info, - Content: "switching dialer", - }) - - if h.net != nil { - _ = h.net.Close() - h.net = nil - } - if h.bind != nil { - _ = h.bind.Close() - h.bind = nil - } - - // bind := conn.NewStdNetBind() // TODO: conn.Bind wrapper for dialer - h.bind = &netBindClient{ - netBind: netBind{ - dns: h.dns, - dnsOption: dns.IPOption{ - IPv4Enable: h.hasIPv4, - IPv6Enable: h.hasIPv6, - }, - workers: int(h.conf.NumWorkers), - readQueue: make(chan *netReadInfo), - }, - ctx: ctx, - dialer: dialer, - reserved: h.conf.Reserved, - } - defer func() { - if err != nil { - h.bind.Close() - h.bind = nil - } - }() - - h.net, err = h.makeVirtualTun() - if err != nil { - return errors.New("failed to create virtual tun interface").Base(err) - } - return nil -} - -// Process implements OutboundHandler.Dispatch(). +// Process implements proxy.Outbound.Process. func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error { outbounds := session.OutboundsFromContext(ctx) ob := outbounds[len(outbounds)-1] @@ -152,40 +139,31 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte ob.Name = "wireguard" ob.CanSpliceCopy = 3 - if err := h.processWireGuard(ctx, dialer); err != nil { + if h.dev == nil { + if err := h.init(ctx); err != nil { + return err + } + } + + if err := h.dev.Up(); err != nil { return err } - // Destination of the inner request. - destination := ob.Target - command := protocol.RequestCommandTCP - if destination.Network == net.Network_UDP { - command = protocol.RequestCommandUDP + var addr netip.Addr + if ob.Target.Address.Family().IsDomain() { + ip, err := h.resolveRemote(ob.Target.Address.String()) + if err != nil { + return errors.New("failed to resolve domain").Base(err) + } + addr, _ = netip.AddrFromSlice(ip) + } else { + addr, _ = netip.AddrFromSlice(ob.Target.Address.IP()) } - // resolve dns - addr := destination.Address - if addr.Family().IsDomain() { - ips, _, err := h.dns.LookupIP(addr.Domain(), dns.IPOption{ - IPv4Enable: h.hasIPv4 && h.conf.preferIP4(), - IPv6Enable: h.hasIPv6 && h.conf.preferIP6(), - }) - { // Resolve fallback - if (len(ips) == 0 || err != nil) && h.conf.hasFallback() { - ips, _, err = h.dns.LookupIP(addr.Domain(), dns.IPOption{ - IPv4Enable: h.hasIPv4 && h.conf.fallbackIP4(), - IPv6Enable: h.hasIPv6 && h.conf.fallbackIP6(), - }) - } - } - if err != nil { - return errors.New("failed to lookup DNS").Base(err) - } else if len(ips) == 0 { - return dns.ErrEmptyResponse - } - addr = net.IPAddress(ips[dice.Roll(len(ips))]) + addrPort := netip.AddrPortFrom(addr, ob.Target.Port.Value()) + if !addrPort.IsValid() { + return errors.New("invalid target ", ob.Target) } - destination.Address = addr var newCtx context.Context var newCancel context.CancelFunc @@ -193,59 +171,64 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte newCtx, newCancel = context.WithCancel(context.Background()) } - p := h.policyManager.ForLevel(0) - + sessionPolicy := h.policyManager.ForLevel(0) ctx, cancel := context.WithCancel(ctx) timer := signal.CancelAfterInactivity(ctx, func() { cancel() if newCancel != nil { newCancel() } - }, p.Timeouts.ConnectionIdle) - addrPort := netip.AddrPortFrom(toNetIpAddr(addr), destination.Port.Value()) + }, sessionPolicy.Timeouts.ConnectionIdle) - var requestFunc func() error - var responseFunc func() error + if newCtx != nil { + ctx = newCtx + } - if command == protocol.RequestCommandTCP { - conn, err := h.net.DialContextTCPAddrPort(ctx, addrPort) + var reader buf.Reader + var writer buf.Writer + + switch ob.Target.Network { + case net.Network_TCP: + var conn net.Conn + var err error + if sessionPolicy.Timeouts.Handshake != 0 { + timeoutCtx, timeoutCancel := context.WithTimeout(ctx, sessionPolicy.Timeouts.Handshake) + conn, err = h.tnet.DialContextTCPAddrPort(timeoutCtx, addrPort) + timeoutCancel() + } else { + conn, err = h.tnet.DialContextTCPAddrPort(ctx, addrPort) + } if err != nil { return errors.New("failed to create TCP connection").Base(err) } defer conn.Close() - - requestFunc = func() error { - defer timer.SetTimeout(p.Timeouts.DownlinkOnly) - return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer)) - } - responseFunc = func() error { - defer timer.SetTimeout(p.Timeouts.UplinkOnly) - return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer)) - } - } else if command == protocol.RequestCommandUDP { - conn, err := h.net.DialUDPAddrPort(netip.AddrPort{}, addrPort) + reader = buf.NewReader(conn) + writer = buf.NewWriter(conn) + case net.Network_UDP: + conn, err := h.tnet.DialUDPAddrPort(netip.AddrPort{}, addrPort) if err != nil { return errors.New("failed to create UDP connection").Base(err) } defer conn.Close() - - conn = &udpConnClient{ - Conn: conn, - dest: destination, - } - - requestFunc = func() error { - defer timer.SetTimeout(p.Timeouts.DownlinkOnly) - return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer)) - } - responseFunc = func() error { - defer timer.SetTimeout(p.Timeouts.UplinkOnly) - return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer)) + c := &udpConnClient{ + PacketConn: conn.(*internet.PacketConnWrapper).PacketConn, + resolveFunc: h.resolveRemote, + dest: gonet.UDPAddrFromAddrPort(addrPort), } + reader = c + writer = c + default: + panic(ob.Target.Network) } - if newCtx != nil { - ctx = newCtx + requestFunc := func() error { + defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly) + return buf.Copy(link.Reader, writer, buf.UpdateActivity(timer)) + } + + responseFunc := func() error { + defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly) + return buf.Copy(reader, link.Writer, buf.UpdateActivity(timer)) } responseDonePost := task.OnSuccess(responseFunc, task.Close(link.Writer)) @@ -258,108 +241,191 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte return nil } -// creates a tun interface on netstack given a configuration -func (h *Handler) makeVirtualTun() (Tunnel, error) { - t, err := h.conf.createTun()(h.endpoints, int(h.conf.Mtu), nil) +func (h *Handler) Close() (err error) { + h.mu.Lock() + defer h.mu.Unlock() + if h.dev != nil { + h.dev.Close() + h.dev = nil + h.tun = nil + } else if h.tun != nil { + h.tun.Close() + h.tun = nil + } + return nil +} + +func (h *Handler) init(ctx context.Context) error { + h.mu.Lock() + defer h.mu.Unlock() + if h.dev != nil { + return nil + } + resolveFunc := h.resolveLocal + listenFunc := func() (net.PacketConn, error) { + dest, err := net.ParseDestination("udp:" + h.conf.Peers[0].Endpoint) + if err != nil { + return nil, err + } + conn, err := internet.DialSystem(ctx, dest, h.streamSettings.SocketSettings) + if err != nil { + return nil, err + } + var pktConn net.PacketConn + switch c := conn.(type) { + case *internet.PacketConnWrapper: + pktConn = c.PacketConn + case *cnc.Connection: + pktConn = &internet.FakePacketConn{Conn: c} + default: + panic(reflect.TypeOf(c)) + } + if h.streamSettings.UdpmaskManager != nil { + newConn, err := h.streamSettings.UdpmaskManager.WrapPacketConnClient(pktConn) + if err != nil { + pktConn.Close() + return nil, errors.New("mask err").Base(err) + } + pktConn = newConn + } + if h.uplinkCounter != nil || h.downlinkCounter != nil { + pktConn = &PacketCounterConnection{ + PacketConn: pktConn, + ReadCounter: h.downlinkCounter, + WriteCounter: h.uplinkCounter, + } + } + return pktConn, nil + } + bind := &bind{} + logger := &device.Logger{ + Verbosef: func(format string, args ...any) { + log.Record(&log.GeneralMessage{ + Severity: log.Severity_Debug, + Content: fmt.Sprintf(format, args...), + }) + }, + Errorf: func(format string, args ...any) { + log.Record(&log.GeneralMessage{ + Severity: log.Severity_Error, + Content: fmt.Sprintf(format, args...), + }) + }, + } + dev := device.NewDevice(h.tun, bind, logger) + bind.resolveFunc = resolveFunc + bind.listenFunc = listenFunc + bind.downFunc = dev.Down + bind.reserved = h.conf.Reserved + var cfg strings.Builder + cfg.WriteString("private_key=" + h.conf.SecretKey + "\n") + for _, peer := range h.conf.Peers { + cfg.WriteString("public_key=" + peer.PublicKey + "\n") + if peer.PreSharedKey != "" { + cfg.WriteString("preshared_key=" + peer.PreSharedKey + "\n") + } + cfg.WriteString("endpoint=" + peer.Endpoint + "\n") + for _, ip := range peer.AllowedIps { + cfg.WriteString("allowed_ip=" + ip + "\n") + } + if peer.KeepAlive != "" { + cfg.WriteString("persistent_keepalive_interval=" + peer.KeepAlive + "\n") + } + } + err := dev.IpcSet(cfg.String()) + if err != nil { + return err + } + err = dev.Up() + if err != nil { + return err + } + h.dev = dev + return nil +} + +func (h *Handler) resolveLocal(host string) (net.IP, error) { + return resolveDomain(host, h.conf.DomainStrategy, func(host string) ([]net.IP, error) { + ips, _, err := h.dns.LookupIP(host, dns.IPOption{IPv4Enable: true, IPv6Enable: true}) + return ips, err + }) +} + +func (h *Handler) resolveRemote(host string) (net.IP, error) { + return resolveDomain(host, h.conf.DomainStrategy, func(host string) ([]net.IP, error) { + addrs, err := h.tnet.LookupHost(host) + if err != nil { + return nil, err + } + ips := make([]net.IP, 0, len(addrs)) + for _, addr := range addrs { + ips = append(ips, net.ParseIP(addr)) + } + return ips, nil + }) +} + +func resolveDomain(host string, strategy DeviceConfig_DomainStrategy, lookupIP func(host string) ([]net.IP, error)) (net.IP, error) { + if ip := net.ParseIP(host); ip != nil { + return ip, nil + } + ips, err := lookupIP(host) if err != nil { return nil, err } - - h.bind.dnsOption.IPv4Enable = h.hasIPv4 - h.bind.dnsOption.IPv6Enable = h.hasIPv6 - - if err = t.BuildDevice(h.createIPCRequest(), h.bind); err != nil { - _ = t.Close() - return nil, err + if len(ips) == 0 { + return nil, dns.ErrEmptyResponse } - return t, nil -} - -// serialize the config into an IPC request -func (h *Handler) createIPCRequest() string { - var request strings.Builder - - request.WriteString(fmt.Sprintf("private_key=%s\n", h.conf.SecretKey)) - - if !h.conf.IsClient { - // placeholder, we'll handle actual port listening on Xray - request.WriteString("listen_port=1337\n") - } - - for _, peer := range h.conf.Peers { - if peer.PublicKey != "" { - request.WriteString(fmt.Sprintf("public_key=%s\n", peer.PublicKey)) - } - - if peer.PreSharedKey != "" { - request.WriteString(fmt.Sprintf("preshared_key=%s\n", peer.PreSharedKey)) - } - - address, port, err := net.SplitHostPort(peer.Endpoint) - if err != nil { - errors.LogError(h.bind.ctx, "failed to split endpoint ", peer.Endpoint, " into address and port") - } - addr := net.ParseAddress(address) - if addr.Family().IsDomain() { - dialerIp := h.bind.dialer.DestIpAddress() - if dialerIp != nil { - addr = net.ParseAddress(dialerIp.String()) - errors.LogInfo(h.bind.ctx, "createIPCRequest use dialer dest ip: ", addr) - } else { - ips, _, err := h.dns.LookupIP(addr.Domain(), dns.IPOption{ - IPv4Enable: h.conf.preferIP4(), - IPv6Enable: h.conf.preferIP6(), - }) - { // Resolve fallback - if (len(ips) == 0 || err != nil) && h.conf.hasFallback() { - ips, _, err = h.dns.LookupIP(addr.Domain(), dns.IPOption{ - IPv4Enable: h.conf.fallbackIP4(), - IPv6Enable: h.conf.fallbackIP6(), - }) - } - } - if err != nil { - errors.LogInfoInner(h.bind.ctx, err, "createIPCRequest failed to lookup DNS") - } else if len(ips) == 0 { - errors.LogInfo(h.bind.ctx, "createIPCRequest empty lookup DNS") - } else { - addr = net.IPAddress(ips[dice.Roll(len(ips))]) - } - } - } - - if peer.Endpoint != "" { - request.WriteString(fmt.Sprintf("endpoint=%s:%s\n", addr, port)) - } - - for _, ip := range peer.AllowedIps { - request.WriteString(fmt.Sprintf("allowed_ip=%s\n", ip)) - } - - if peer.KeepAlive != 0 { - request.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", peer.KeepAlive)) + var got4, got6 []net.IP + for _, ip := range ips { + if ip.To4() != nil { + got4 = append(got4, ip) + } else { + got6 = append(got6, ip) } } - - return request.String()[:request.Len()] + var got []net.IP + switch strategy { + case DeviceConfig_FORCE_IP: + got = ips + return ips[dice.Roll(len(ips))], nil + case DeviceConfig_FORCE_IP4: + got = got4 + case DeviceConfig_FORCE_IP6: + got = got6 + case DeviceConfig_FORCE_IP46: + got = got4 + if len(got) == 0 { + got = got6 + } + case DeviceConfig_FORCE_IP64: + got = got6 + if len(got) == 0 { + got = got4 + } + default: + panic(strategy) + } + if len(got) == 0 { + return nil, dns.ErrEmptyResponse + } + return got[dice.Roll(len(got))], nil } type udpConnClient struct { - net.Conn - dest net.Destination + net.PacketConn + resolveFunc func(host string) (net.IP, error) + dest *net.UDPAddr } func (c *udpConnClient) ReadMultiBuffer() (buf.MultiBuffer, error) { b := buf.New() b.Resize(0, buf.Size) - n, addr, err := c.Conn.(net.PacketConn).ReadFrom(b.Bytes()) + n, addr, err := c.PacketConn.ReadFrom(b.Bytes()) if err != nil { b.Release() return nil, err } - if addr == nil { // should never hit - addr = c.dest.RawNetAddr() - } b.Resize(0, int32(n)) b.UDP = &net.Destination{ @@ -375,9 +441,22 @@ func (c *udpConnClient) WriteMultiBuffer(mb buf.MultiBuffer) error { for i, b := range mb { dst := c.dest if b.UDP != nil { - dst = *b.UDP + if b.UDP.Address.Family().IsDomain() { + ip, err := c.resolveFunc(b.UDP.Address.String()) + if err != nil { + errors.LogErrorInner(context.Background(), err, "drop packet to ", b.UDP, " with size ", len(b.Bytes())) + b.Release() + continue + } + dst = &net.UDPAddr{ + IP: ip, + Port: int(b.UDP.Port), + } + } else { + dst = b.UDP.RawNetAddr().(*net.UDPAddr) + } } - _, err := c.Conn.(net.PacketConn).WriteTo(b.Bytes(), dst.RawNetAddr()) + _, err := c.PacketConn.WriteTo(b.Bytes(), dst) if err != nil { buf.ReleaseMulti(mb[i:]) return err @@ -386,3 +465,25 @@ func (c *udpConnClient) WriteMultiBuffer(mb buf.MultiBuffer) error { } return nil } + +type PacketCounterConnection struct { + net.PacketConn + ReadCounter stats.Counter + WriteCounter stats.Counter +} + +func (c *PacketCounterConnection) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + n, addr, err = c.PacketConn.ReadFrom(p) + if err == nil && c.ReadCounter != nil { + c.ReadCounter.Add(int64(n)) + } + return +} + +func (c *PacketCounterConnection) WriteTo(p []byte, addr net.Addr) (n int, err error) { + n, err = c.PacketConn.WriteTo(p, addr) + if err == nil && c.WriteCounter != nil { + c.WriteCounter.Add(int64(n)) + } + return +} diff --git a/proxy/wireguard/config.go b/proxy/wireguard/config.go index cbaa670b7..0b280a906 100644 --- a/proxy/wireguard/config.go +++ b/proxy/wireguard/config.go @@ -1,54 +1 @@ package wireguard - -import ( - "context" - - "github.com/xtls/xray-core/common/errors" -) - -func (c *DeviceConfig) preferIP4() bool { - return c.DomainStrategy == DeviceConfig_FORCE_IP || - c.DomainStrategy == DeviceConfig_FORCE_IP4 || - c.DomainStrategy == DeviceConfig_FORCE_IP46 -} - -func (c *DeviceConfig) preferIP6() bool { - return c.DomainStrategy == DeviceConfig_FORCE_IP || - c.DomainStrategy == DeviceConfig_FORCE_IP6 || - c.DomainStrategy == DeviceConfig_FORCE_IP64 -} - -func (c *DeviceConfig) hasFallback() bool { - return c.DomainStrategy == DeviceConfig_FORCE_IP46 || c.DomainStrategy == DeviceConfig_FORCE_IP64 -} - -func (c *DeviceConfig) fallbackIP4() bool { - return c.DomainStrategy == DeviceConfig_FORCE_IP64 -} - -func (c *DeviceConfig) fallbackIP6() bool { - return c.DomainStrategy == DeviceConfig_FORCE_IP46 -} - -func (c *DeviceConfig) createTun() tunCreator { - if !c.IsClient { - // See tun_linux.go createKernelTun() - errors.LogWarning(context.Background(), "Using gVisor TUN. WG inbound doesn't support kernel TUN yet.") - return createGVisorTun - } - if c.NoKernelTun { - errors.LogWarning(context.Background(), "Using gVisor TUN. NoKernelTun is set to true.") - return createGVisorTun - } - kernelTunSupported, err := KernelTunSupported() - if err != nil { - errors.LogWarning(context.Background(), "Using gVisor TUN. Failed to check kernel TUN support:", err) - return createGVisorTun - } - if !kernelTunSupported { - errors.LogWarning(context.Background(), "Using gVisor TUN. Kernel TUN is not supported on your OS, or your permission is insufficient.") - return createGVisorTun - } - errors.LogWarning(context.Background(), "Using kernel TUN.") - return createKernelTun -} diff --git a/proxy/wireguard/config.pb.go b/proxy/wireguard/config.pb.go index 16134c8d6..5a5ab288b 100644 --- a/proxy/wireguard/config.pb.go +++ b/proxy/wireguard/config.pb.go @@ -81,7 +81,7 @@ type PeerConfig struct { PublicKey string `protobuf:"bytes,1,opt,name=public_key,json=publicKey,proto3" json:"public_key,omitempty"` PreSharedKey string `protobuf:"bytes,2,opt,name=pre_shared_key,json=preSharedKey,proto3" json:"pre_shared_key,omitempty"` Endpoint string `protobuf:"bytes,3,opt,name=endpoint,proto3" json:"endpoint,omitempty"` - KeepAlive uint32 `protobuf:"varint,4,opt,name=keep_alive,json=keepAlive,proto3" json:"keep_alive,omitempty"` + KeepAlive string `protobuf:"bytes,4,opt,name=keep_alive,json=keepAlive,proto3" json:"keep_alive,omitempty"` AllowedIps []string `protobuf:"bytes,5,rep,name=allowed_ips,json=allowedIps,proto3" json:"allowed_ips,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache @@ -138,11 +138,11 @@ func (x *PeerConfig) GetEndpoint() string { return "" } -func (x *PeerConfig) GetKeepAlive() uint32 { +func (x *PeerConfig) GetKeepAlive() string { if x != nil { return x.KeepAlive } - return 0 + return "" } func (x *PeerConfig) GetAllowedIps() []string { @@ -158,7 +158,6 @@ type DeviceConfig struct { Endpoint []string `protobuf:"bytes,2,rep,name=endpoint,proto3" json:"endpoint,omitempty"` Peers []*PeerConfig `protobuf:"bytes,3,rep,name=peers,proto3" json:"peers,omitempty"` Mtu int32 `protobuf:"varint,4,opt,name=mtu,proto3" json:"mtu,omitempty"` - NumWorkers int32 `protobuf:"varint,5,opt,name=num_workers,json=numWorkers,proto3" json:"num_workers,omitempty"` Reserved []byte `protobuf:"bytes,6,opt,name=reserved,proto3" json:"reserved,omitempty"` DomainStrategy DeviceConfig_DomainStrategy `protobuf:"varint,7,opt,name=domain_strategy,json=domainStrategy,proto3,enum=xray.proxy.wireguard.DeviceConfig_DomainStrategy" json:"domain_strategy,omitempty"` IsClient bool `protobuf:"varint,8,opt,name=is_client,json=isClient,proto3" json:"is_client,omitempty"` @@ -225,13 +224,6 @@ func (x *DeviceConfig) GetMtu() int32 { return 0 } -func (x *DeviceConfig) GetNumWorkers() int32 { - if x != nil { - return x.NumWorkers - } - return 0 -} - func (x *DeviceConfig) GetReserved() []byte { if x != nil { return x.Reserved @@ -272,17 +264,15 @@ const file_proxy_wireguard_config_proto_rawDesc = "" + "\x0epre_shared_key\x18\x02 \x01(\tR\fpreSharedKey\x12\x1a\n" + "\bendpoint\x18\x03 \x01(\tR\bendpoint\x12\x1d\n" + "\n" + - "keep_alive\x18\x04 \x01(\rR\tkeepAlive\x12\x1f\n" + + "keep_alive\x18\x04 \x01(\tR\tkeepAlive\x12\x1f\n" + "\vallowed_ips\x18\x05 \x03(\tR\n" + - "allowedIps\"\xcb\x03\n" + + "allowedIps\"\xaa\x03\n" + "\fDeviceConfig\x12\x1d\n" + "\n" + "secret_key\x18\x01 \x01(\tR\tsecretKey\x12\x1a\n" + "\bendpoint\x18\x02 \x03(\tR\bendpoint\x126\n" + "\x05peers\x18\x03 \x03(\v2 .xray.proxy.wireguard.PeerConfigR\x05peers\x12\x10\n" + - "\x03mtu\x18\x04 \x01(\x05R\x03mtu\x12\x1f\n" + - "\vnum_workers\x18\x05 \x01(\x05R\n" + - "numWorkers\x12\x1a\n" + + "\x03mtu\x18\x04 \x01(\x05R\x03mtu\x12\x1a\n" + "\breserved\x18\x06 \x01(\fR\breserved\x12Z\n" + "\x0fdomain_strategy\x18\a \x01(\x0e21.xray.proxy.wireguard.DeviceConfig.DomainStrategyR\x0edomainStrategy\x12\x1b\n" + "\tis_client\x18\b \x01(\bR\bisClient\x12\"\n" + diff --git a/proxy/wireguard/config.proto b/proxy/wireguard/config.proto index aa05b822b..95291279b 100644 --- a/proxy/wireguard/config.proto +++ b/proxy/wireguard/config.proto @@ -10,7 +10,7 @@ message PeerConfig { string public_key = 1; string pre_shared_key = 2; string endpoint = 3; - uint32 keep_alive = 4; + string keep_alive = 4; repeated string allowed_ips = 5; } @@ -26,7 +26,7 @@ message DeviceConfig { repeated string endpoint = 2; repeated PeerConfig peers = 3; int32 mtu = 4; - int32 num_workers = 5; + bytes reserved = 6; DomainStrategy domain_strategy = 7; bool is_client = 8; diff --git a/proxy/wireguard/gvisortun/tun.go b/proxy/wireguard/gvisortun/tun.go deleted file mode 100644 index 379fad424..000000000 --- a/proxy/wireguard/gvisortun/tun.go +++ /dev/null @@ -1,226 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2022 WireGuard LLC. All Rights Reserved. - */ - -package gvisortun - -import ( - "context" - "fmt" - "net/netip" - "os" - "sync" - "syscall" - - "golang.zx2c4.com/wireguard/tun" - "gvisor.dev/gvisor/pkg/buffer" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" - "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" - "gvisor.dev/gvisor/pkg/tcpip/transport/udp" -) - -type netTun struct { - ep *channel.Endpoint - stack *stack.Stack - events chan tun.Event - notifyHandle *channel.NotificationHandle - incomingPacket chan *buffer.View - mtu int - hasV4, hasV6 bool - closeOnce sync.Once -} - -type Net netTun - -func CreateNetTUN(localAddresses []netip.Addr, mtu int, promiscuousMode bool) (tun.Device, *Net, *stack.Stack, error) { - opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}, - HandleLocal: !promiscuousMode, - } - dev := &netTun{ - ep: channel.New(1024, uint32(mtu), ""), - stack: stack.New(opts), - events: make(chan tun.Event, 10), - incomingPacket: make(chan *buffer.View), - mtu: mtu, - } - sackEnabledOpt := tcpip.TCPSACKEnabled(true) // TCP SACK is disabled by default - tcpipErr := dev.stack.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt) - if tcpipErr != nil { - return nil, nil, dev.stack, fmt.Errorf("could not enable TCP SACK: %v", tcpipErr) - } - dev.notifyHandle = dev.ep.AddNotify(dev) - tcpipErr = dev.stack.CreateNIC(1, dev.ep) - if tcpipErr != nil { - return nil, nil, dev.stack, fmt.Errorf("CreateNIC: %v", tcpipErr) - } - for _, ip := range localAddresses { - var protoNumber tcpip.NetworkProtocolNumber - if ip.Is4() { - protoNumber = ipv4.ProtocolNumber - } else if ip.Is6() { - protoNumber = ipv6.ProtocolNumber - } - protoAddr := tcpip.ProtocolAddress{ - Protocol: protoNumber, - AddressWithPrefix: tcpip.AddrFromSlice(ip.AsSlice()).WithPrefix(), - } - tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}) - if tcpipErr != nil { - return nil, nil, dev.stack, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr) - } - if ip.Is4() { - dev.hasV4 = true - } else if ip.Is6() { - dev.hasV6 = true - } - } - if dev.hasV4 { - dev.stack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: 1}) - } - if dev.hasV6 { - dev.stack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: 1}) - } - if promiscuousMode { - // enable promiscuous mode to handle all packets processed by netstack - dev.stack.SetPromiscuousMode(1, true) - dev.stack.SetSpoofing(1, true) - } - - dev.events <- tun.EventUp - return dev, (*Net)(dev), dev.stack, nil -} - -// Name implements tun.Device -func (tun *netTun) Name() (string, error) { - return "go", nil -} - -// File implements tun.Device -func (tun *netTun) File() *os.File { - return nil -} - -// Events implements tun.Device -func (tun *netTun) Events() <-chan tun.Event { - return tun.events -} - -// Read implements tun.Device -func (tun *netTun) Read(buf [][]byte, sizes []int, offset int) (int, error) { - view, ok := <-tun.incomingPacket - if !ok { - return 0, os.ErrClosed - } - - n, err := view.Read(buf[0][offset:]) - if err != nil { - return 0, err - } - sizes[0] = n - return 1, nil -} - -// Write implements tun.Device -func (tun *netTun) Write(buf [][]byte, offset int) (int, error) { - for _, buf := range buf { - packet := buf[offset:] - if len(packet) == 0 { - continue - } - - pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buffer.MakeWithData(packet)}) - switch packet[0] >> 4 { - case 4: - tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb) - case 6: - tun.ep.InjectInbound(header.IPv6ProtocolNumber, pkb) - default: - return 0, syscall.EAFNOSUPPORT - } - } - return len(buf), nil -} - -// WriteNotify implements channel.Notification -func (tun *netTun) WriteNotify() { - pkt := tun.ep.Read() - if pkt == nil { - return - } - - view := pkt.ToView() - pkt.DecRef() - - tun.incomingPacket <- view -} - -// Close implements tun.Device -func (tun *netTun) Close() error { - tun.closeOnce.Do(func() { - tun.stack.RemoveNIC(1) - tun.stack.Close() - tun.ep.RemoveNotify(tun.notifyHandle) - tun.ep.Close() - - close(tun.events) - - close(tun.incomingPacket) - }) - return nil -} - -// MTU implements tun.Device -func (tun *netTun) MTU() (int, error) { - return tun.mtu, nil -} - -// BatchSize implements tun.Device -func (tun *netTun) BatchSize() int { - return 1 -} - -func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) { - var protoNumber tcpip.NetworkProtocolNumber - if endpoint.Addr().Is4() { - protoNumber = ipv4.ProtocolNumber - } else { - protoNumber = ipv6.ProtocolNumber - } - return tcpip.FullAddress{ - NIC: 1, - Addr: tcpip.AddrFromSlice(endpoint.Addr().AsSlice()), - Port: endpoint.Port(), - }, protoNumber -} - -func (net *Net) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (*gonet.TCPConn, error) { - fa, pn := convertToFullAddr(addr) - return gonet.DialContextTCP(ctx, net.stack, fa, pn) -} - -func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, error) { - var lfa, rfa *tcpip.FullAddress - var pn tcpip.NetworkProtocolNumber - if laddr.IsValid() || laddr.Port() > 0 { - var addr tcpip.FullAddress - addr, pn = convertToFullAddr(laddr) - lfa = &addr - } - if raddr.IsValid() || raddr.Port() > 0 { - var addr tcpip.FullAddress - addr, pn = convertToFullAddr(raddr) - rfa = &addr - rfa = nil // do not ep connect - } - return gonet.DialUDP(net.stack, lfa, rfa, pn) -} diff --git a/proxy/wireguard/netstack.go b/proxy/wireguard/netstack.go new file mode 100644 index 000000000..22c42e3f9 --- /dev/null +++ b/proxy/wireguard/netstack.go @@ -0,0 +1,690 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. + */ + +package wireguard + +import ( + "context" + "crypto/rand" + "encoding/binary" + "errors" + "fmt" + "io" + "net" + "net/netip" + "os" + "strings" + "syscall" + "time" + + "github.com/xtls/xray-core/transport/internet" + "golang.zx2c4.com/wireguard/tun" + + "golang.org/x/net/dns/dnsmessage" + "gvisor.dev/gvisor/pkg/buffer" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" +) + +type netTun struct { + ep *channel.Endpoint + stack *stack.Stack + events chan tun.Event + notifyHandle *channel.NotificationHandle + incomingPacket chan *buffer.View + mtu int + dnsServers []netip.Addr + hasV4, hasV6 bool +} + +func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int, handleLocal bool) (tun.Device, *Net, *stack.Stack, error) { + opts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}, + HandleLocal: handleLocal, + } + dev := &netTun{ + ep: channel.New(1024, uint32(mtu), ""), + stack: stack.New(opts), + events: make(chan tun.Event, 10), + incomingPacket: make(chan *buffer.View), + dnsServers: dnsServers, + mtu: mtu, + } + sackEnabledOpt := tcpip.TCPSACKEnabled(true) // TCP SACK is disabled by default + tcpipErr := dev.stack.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt) + if tcpipErr != nil { + return nil, nil, nil, fmt.Errorf("could not enable TCP SACK: %v", tcpipErr) + } + dev.notifyHandle = dev.ep.AddNotify(dev) + tcpipErr = dev.stack.CreateNIC(1, dev.ep) + if tcpipErr != nil { + return nil, nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr) + } + for _, ip := range localAddresses { + var protoNumber tcpip.NetworkProtocolNumber + if ip.Is4() { + protoNumber = ipv4.ProtocolNumber + } else if ip.Is6() { + protoNumber = ipv6.ProtocolNumber + } + protoAddr := tcpip.ProtocolAddress{ + Protocol: protoNumber, + AddressWithPrefix: tcpip.AddrFromSlice(ip.AsSlice()).WithPrefix(), + } + tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}) + if tcpipErr != nil { + return nil, nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr) + } + if ip.Is4() { + dev.hasV4 = true + } else if ip.Is6() { + dev.hasV6 = true + } + } + if dev.hasV4 { + dev.stack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: 1}) + } + if dev.hasV6 { + dev.stack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: 1}) + } + + tnet := &Net{ + DialContextTCPAddrPort: dev.DialContextTCPAddrPort, + DialUDPAddrPort: dev.DialUDPAddrPort, + dnsServers: dev.dnsServers, + hasV4: dev.hasV4, + hasV6: dev.hasV6, + } + + dev.events <- tun.EventUp + return dev, tnet, dev.stack, nil +} + +func (tun *netTun) Name() (string, error) { + return "go", nil +} + +func (tun *netTun) File() *os.File { + return nil +} + +func (tun *netTun) Events() <-chan tun.Event { + return tun.events +} + +func (tun *netTun) Read(buf [][]byte, sizes []int, offset int) (int, error) { + view, ok := <-tun.incomingPacket + if !ok { + return 0, os.ErrClosed + } + + n, err := view.Read(buf[0][offset:]) + if err != nil { + return 0, err + } + sizes[0] = n + return 1, nil +} + +func (tun *netTun) Write(buf [][]byte, offset int) (int, error) { + for _, buf := range buf { + packet := buf[offset:] + if len(packet) == 0 { + continue + } + + pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buffer.MakeWithData(packet)}) + switch packet[0] >> 4 { + case 4: + tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb) + case 6: + tun.ep.InjectInbound(header.IPv6ProtocolNumber, pkb) + default: + return 0, syscall.EAFNOSUPPORT + } + } + return len(buf), nil +} + +func (tun *netTun) WriteNotify() { + pkt := tun.ep.Read() + if pkt == nil { + return + } + + view := pkt.ToView() + pkt.DecRef() + + tun.incomingPacket <- view +} + +func (tun *netTun) Close() error { + tun.stack.RemoveNIC(1) + tun.stack.Close() + tun.ep.RemoveNotify(tun.notifyHandle) + tun.ep.Close() + + if tun.events != nil { + close(tun.events) + } + + if tun.incomingPacket != nil { + close(tun.incomingPacket) + } + + return nil +} + +func (tun *netTun) MTU() (int, error) { + return tun.mtu, nil +} + +func (tun *netTun) BatchSize() int { + return 1 +} + +func (tun *netTun) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (net.Conn, error) { + fa, pn := convertToFullAddr(addr) + return gonet.DialContextTCP(ctx, tun.stack, fa, pn) +} + +func (tun *netTun) DialUDPAddrPort(laddr, raddr netip.AddrPort) (net.Conn, error) { + var pn tcpip.NetworkProtocolNumber = ipv6.ProtocolNumber + if raddr.IsValid() || raddr.Port() > 0 { + _, pn = convertToFullAddr(raddr) + } + conn, err := gonet.DialUDP(tun.stack, nil, nil, pn) + if err != nil { + return nil, err + } + return &internet.PacketConnWrapper{ + PacketConn: conn, + Dest: net.UDPAddrFromAddrPort(raddr), + }, nil +} + +type Net struct { + DialContextTCPAddrPort func(ctx context.Context, addr netip.AddrPort) (net.Conn, error) + DialUDPAddrPort func(laddr, raddr netip.AddrPort) (net.Conn, error) + dnsServers []netip.Addr + hasV4, hasV6 bool +} + +func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) { + var protoNumber tcpip.NetworkProtocolNumber + if endpoint.Addr().Is4() { + protoNumber = ipv4.ProtocolNumber + } else { + protoNumber = ipv6.ProtocolNumber + } + return tcpip.FullAddress{ + NIC: 1, + Addr: tcpip.AddrFromSlice(endpoint.Addr().AsSlice()), + Port: endpoint.Port(), + }, protoNumber +} + +var ( + errNoSuchHost = errors.New("no such host") + errLameReferral = errors.New("lame referral") + errCannotUnmarshalDNSMessage = errors.New("cannot unmarshal DNS message") + errCannotMarshalDNSMessage = errors.New("cannot marshal DNS message") + errServerMisbehaving = errors.New("server misbehaving") + errInvalidDNSResponse = errors.New("invalid DNS response") + errNoAnswerFromDNSServer = errors.New("no answer from DNS server") + errServerTemporarilyMisbehaving = errors.New("server misbehaving") + errCanceled = errors.New("operation was canceled") + errTimeout = errors.New("i/o timeout") +) + +func (net *Net) LookupHost(host string) (addrs []string, err error) { + return net.LookupContextHost(context.Background(), host) +} + +func isDomainName(s string) bool { + l := len(s) + if l == 0 || l > 254 || l == 254 && s[l-1] != '.' { + return false + } + last := byte('.') + nonNumeric := false + partlen := 0 + for i := 0; i < len(s); i++ { + c := s[i] + switch { + default: + return false + case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || c == '_': + nonNumeric = true + partlen++ + case '0' <= c && c <= '9': + partlen++ + case c == '-': + if last == '.' { + return false + } + partlen++ + nonNumeric = true + case c == '.': + if last == '.' || last == '-' { + return false + } + if partlen > 63 || partlen == 0 { + return false + } + partlen = 0 + } + last = c + } + if last == '-' || partlen > 63 { + return false + } + return nonNumeric +} + +func randU16() uint16 { + var b [2]byte + _, err := rand.Read(b[:]) + if err != nil { + panic(err) + } + return binary.LittleEndian.Uint16(b[:]) +} + +func newRequest(q dnsmessage.Question) (id uint16, udpReq, tcpReq []byte, err error) { + id = randU16() + b := dnsmessage.NewBuilder(make([]byte, 2, 514), dnsmessage.Header{ID: id, RecursionDesired: true}) + b.EnableCompression() + if err := b.StartQuestions(); err != nil { + return 0, nil, nil, err + } + if err := b.Question(q); err != nil { + return 0, nil, nil, err + } + tcpReq, err = b.Finish() + udpReq = tcpReq[2:] + l := len(tcpReq) - 2 + tcpReq[0] = byte(l >> 8) + tcpReq[1] = byte(l) + return id, udpReq, tcpReq, err +} + +func equalASCIIName(x, y dnsmessage.Name) bool { + if x.Length != y.Length { + return false + } + for i := 0; i < int(x.Length); i++ { + a := x.Data[i] + b := y.Data[i] + if 'A' <= a && a <= 'Z' { + a += 0x20 + } + if 'A' <= b && b <= 'Z' { + b += 0x20 + } + if a != b { + return false + } + } + return true +} + +func checkResponse(reqID uint16, reqQues dnsmessage.Question, respHdr dnsmessage.Header, respQues dnsmessage.Question) bool { + if !respHdr.Response { + return false + } + if reqID != respHdr.ID { + return false + } + if reqQues.Type != respQues.Type || reqQues.Class != respQues.Class || !equalASCIIName(reqQues.Name, respQues.Name) { + return false + } + return true +} + +func dnsPacketRoundTrip(c net.Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) { + if _, err := c.Write(b); err != nil { + return dnsmessage.Parser{}, dnsmessage.Header{}, err + } + b = make([]byte, 512) + for { + n, err := c.Read(b) + if err != nil { + return dnsmessage.Parser{}, dnsmessage.Header{}, err + } + var p dnsmessage.Parser + h, err := p.Start(b[:n]) + if err != nil { + continue + } + q, err := p.Question() + if err != nil || !checkResponse(id, query, h, q) { + continue + } + return p, h, nil + } +} + +func dnsStreamRoundTrip(c net.Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) { + if _, err := c.Write(b); err != nil { + return dnsmessage.Parser{}, dnsmessage.Header{}, err + } + b = make([]byte, 1280) + if _, err := io.ReadFull(c, b[:2]); err != nil { + return dnsmessage.Parser{}, dnsmessage.Header{}, err + } + l := int(b[0])<<8 | int(b[1]) + if l > len(b) { + b = make([]byte, l) + } + n, err := io.ReadFull(c, b[:l]) + if err != nil { + return dnsmessage.Parser{}, dnsmessage.Header{}, err + } + var p dnsmessage.Parser + h, err := p.Start(b[:n]) + if err != nil { + return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotUnmarshalDNSMessage + } + q, err := p.Question() + if err != nil { + return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotUnmarshalDNSMessage + } + if !checkResponse(id, query, h, q) { + return dnsmessage.Parser{}, dnsmessage.Header{}, errInvalidDNSResponse + } + return p, h, nil +} + +func (tnet *Net) exchange(ctx context.Context, server netip.Addr, q dnsmessage.Question, timeout time.Duration) (dnsmessage.Parser, dnsmessage.Header, error) { + q.Class = dnsmessage.ClassINET + id, udpReq, tcpReq, err := newRequest(q) + if err != nil { + return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotMarshalDNSMessage + } + + for _, useUDP := range []bool{true, false} { + ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout)) + defer cancel() + + var c net.Conn + var err error + if useUDP { + c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, netip.AddrPortFrom(server, 53)) + } else { + c, err = tnet.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(server, 53)) + } + + if err != nil { + return dnsmessage.Parser{}, dnsmessage.Header{}, err + } + if d, ok := ctx.Deadline(); ok && !d.IsZero() { + err := c.SetDeadline(d) + if err != nil { + return dnsmessage.Parser{}, dnsmessage.Header{}, err + } + } + var p dnsmessage.Parser + var h dnsmessage.Header + if useUDP { + p, h, err = dnsPacketRoundTrip(c, id, q, udpReq) + } else { + p, h, err = dnsStreamRoundTrip(c, id, q, tcpReq) + } + c.Close() + if err != nil { + if err == context.Canceled { + err = errCanceled + } else if err == context.DeadlineExceeded { + err = errTimeout + } + return dnsmessage.Parser{}, dnsmessage.Header{}, err + } + if err := p.SkipQuestion(); err != dnsmessage.ErrSectionDone { + return dnsmessage.Parser{}, dnsmessage.Header{}, errInvalidDNSResponse + } + if h.Truncated { + continue + } + return p, h, nil + } + return dnsmessage.Parser{}, dnsmessage.Header{}, errNoAnswerFromDNSServer +} + +func checkHeader(p *dnsmessage.Parser, h dnsmessage.Header) error { + if h.RCode == dnsmessage.RCodeNameError { + return errNoSuchHost + } + _, err := p.AnswerHeader() + if err != nil && err != dnsmessage.ErrSectionDone { + return errCannotUnmarshalDNSMessage + } + if h.RCode == dnsmessage.RCodeSuccess && !h.Authoritative && !h.RecursionAvailable && err == dnsmessage.ErrSectionDone { + return errLameReferral + } + if h.RCode != dnsmessage.RCodeSuccess && h.RCode != dnsmessage.RCodeNameError { + if h.RCode == dnsmessage.RCodeServerFailure { + return errServerTemporarilyMisbehaving + } + return errServerMisbehaving + } + return nil +} + +func skipToAnswer(p *dnsmessage.Parser, qtype dnsmessage.Type) error { + for { + h, err := p.AnswerHeader() + if err == dnsmessage.ErrSectionDone { + return errNoSuchHost + } + if err != nil { + return errCannotUnmarshalDNSMessage + } + if h.Type == qtype { + return nil + } + if err := p.SkipAnswer(); err != nil { + return errCannotUnmarshalDNSMessage + } + } +} + +func (tnet *Net) tryOneName(ctx context.Context, name string, qtype dnsmessage.Type) (dnsmessage.Parser, string, error) { + var lastErr error + + n, err := dnsmessage.NewName(name) + if err != nil { + return dnsmessage.Parser{}, "", errCannotMarshalDNSMessage + } + q := dnsmessage.Question{ + Name: n, + Type: qtype, + Class: dnsmessage.ClassINET, + } + + for i := 0; i < 2; i++ { + for _, server := range tnet.dnsServers { + p, h, err := tnet.exchange(ctx, server, q, time.Second*5) + if err != nil { + dnsErr := &net.DNSError{ + Err: err.Error(), + Name: name, + Server: server.String(), + } + if nerr, ok := err.(net.Error); ok && nerr.Timeout() { + dnsErr.IsTimeout = true + } + if _, ok := err.(*net.OpError); ok { + dnsErr.IsTemporary = true + } + lastErr = dnsErr + continue + } + + if err := checkHeader(&p, h); err != nil { + dnsErr := &net.DNSError{ + Err: err.Error(), + Name: name, + Server: server.String(), + } + if err == errServerTemporarilyMisbehaving { + dnsErr.IsTemporary = true + } + if err == errNoSuchHost { + dnsErr.IsNotFound = true + return p, server.String(), dnsErr + } + lastErr = dnsErr + continue + } + + err = skipToAnswer(&p, qtype) + if err == nil { + return p, server.String(), nil + } + lastErr = &net.DNSError{ + Err: err.Error(), + Name: name, + Server: server.String(), + } + if err == errNoSuchHost { + lastErr.(*net.DNSError).IsNotFound = true + return p, server.String(), lastErr + } + } + } + return dnsmessage.Parser{}, "", lastErr +} + +func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string, error) { + if host == "" || (!tnet.hasV6 && !tnet.hasV4) { + return nil, &net.DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true} + } + zlen := len(host) + if strings.IndexByte(host, ':') != -1 { + if zidx := strings.LastIndexByte(host, '%'); zidx != -1 { + zlen = zidx + } + } + if ip, err := netip.ParseAddr(host[:zlen]); err == nil { + return []string{ip.String()}, nil + } + + if !isDomainName(host) { + return nil, &net.DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true} + } + type result struct { + p dnsmessage.Parser + server string + error + } + var addrsV4, addrsV6 []netip.Addr + lanes := 0 + if tnet.hasV4 { + lanes++ + } + if tnet.hasV6 { + lanes++ + } + lane := make(chan result, lanes) + var lastErr error + if tnet.hasV4 { + go func() { + p, server, err := tnet.tryOneName(ctx, host+".", dnsmessage.TypeA) + lane <- result{p, server, err} + }() + } + if tnet.hasV6 { + go func() { + p, server, err := tnet.tryOneName(ctx, host+".", dnsmessage.TypeAAAA) + lane <- result{p, server, err} + }() + } + for l := 0; l < lanes; l++ { + result := <-lane + if result.error != nil { + if lastErr == nil { + lastErr = result.error + } + continue + } + + loop: + for { + h, err := result.p.AnswerHeader() + if err != nil && err != dnsmessage.ErrSectionDone { + lastErr = &net.DNSError{ + Err: errCannotMarshalDNSMessage.Error(), + Name: host, + Server: result.server, + } + } + if err != nil { + break + } + switch h.Type { + case dnsmessage.TypeA: + a, err := result.p.AResource() + if err != nil { + lastErr = &net.DNSError{ + Err: errCannotMarshalDNSMessage.Error(), + Name: host, + Server: result.server, + } + break loop + } + addrsV4 = append(addrsV4, netip.AddrFrom4(a.A)) + + case dnsmessage.TypeAAAA: + aaaa, err := result.p.AAAAResource() + if err != nil { + lastErr = &net.DNSError{ + Err: errCannotMarshalDNSMessage.Error(), + Name: host, + Server: result.server, + } + break loop + } + addrsV6 = append(addrsV6, netip.AddrFrom16(aaaa.AAAA)) + + default: + if err := result.p.SkipAnswer(); err != nil { + lastErr = &net.DNSError{ + Err: errCannotMarshalDNSMessage.Error(), + Name: host, + Server: result.server, + } + break loop + } + continue + } + } + } + // We don't do RFC6724. Instead just put V6 addresses first if an IPv6 address is enabled + var addrs []netip.Addr + if tnet.hasV6 { + addrs = append(addrsV6, addrsV4...) + } else { + addrs = append(addrsV4, addrsV6...) + } + + if len(addrs) == 0 && lastErr != nil { + return nil, lastErr + } + saddrs := make([]string, 0, len(addrs)) + for _, ip := range addrs { + saddrs = append(saddrs, ip.String()) + } + return saddrs, nil +} diff --git a/proxy/wireguard/server.go b/proxy/wireguard/server.go index 876d749f7..2c634464c 100644 --- a/proxy/wireguard/server.go +++ b/proxy/wireguard/server.go @@ -2,6 +2,10 @@ package wireguard import ( "context" + "fmt" + "net/netip" + "strings" + "sync" "github.com/xtls/xray-core/common/buf" c "github.com/xtls/xray-core/common/ctx" @@ -10,162 +14,246 @@ import ( "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/session" "github.com/xtls/xray-core/core" - "github.com/xtls/xray-core/features/dns" "github.com/xtls/xray-core/features/policy" "github.com/xtls/xray-core/features/routing" + "github.com/xtls/xray-core/features/stats" "github.com/xtls/xray-core/transport" + "github.com/xtls/xray-core/transport/internet" "github.com/xtls/xray-core/transport/internet/stat" + "golang.zx2c4.com/wireguard/device" + "golang.zx2c4.com/wireguard/tun" + "gvisor.dev/gvisor/pkg/tcpip/stack" ) -var nullDestination = net.TCPDestination(net.AnyIP, 0) - type Server struct { - bindServer *netBindServer - - info routingInfo + conf *DeviceConfig + ctx context.Context policyManager policy.Manager -} + dispatcher routing.Dispatcher -type routingInfo struct { - ctx context.Context - dispatcher routing.Dispatcher - inboundTag *session.Inbound - contentTag *session.Content + tag string + src net.Destination + sniffingRequest session.SniffingRequest + streamSettings *internet.MemoryStreamConfig + uplinkCounter stats.Counter + downlinkCounter stats.Counter + + tun tun.Device + stack *stack.Stack + dev *device.Device + mu sync.Mutex } func NewServer(ctx context.Context, conf *DeviceConfig) (*Server, error) { v := core.MustFromContext(ctx) + p := v.GetFeature(policy.ManagerType()).(policy.Manager) + d := v.GetFeature(routing.DispatcherType()).(routing.Dispatcher) - endpoints, hasIPv4, hasIPv6, err := parseEndpoints(conf) + inbound := session.InboundFromContext(ctx) + content := session.ContentFromContext(ctx) + streamSettings := session.StreamSettingsFromContext(ctx).(*internet.MemoryStreamConfig) + tag := inbound.Tag + var uplinkCounter stats.Counter + var downlinkCounter stats.Counter + if len(tag) > 0 && p.ForSystem().Stats.InboundUplink { + statsManager := v.GetFeature(stats.ManagerType()).(stats.Manager) + name := "inbound>>>" + tag + ">>>traffic>>>uplink" + c, _ := stats.GetOrRegisterCounter(statsManager, name) + if c != nil { + uplinkCounter = c + } + } + if len(tag) > 0 && p.ForSystem().Stats.InboundDownlink { + statsManager := v.GetFeature(stats.ManagerType()).(stats.Manager) + name := "inbound>>>" + tag + ">>>traffic>>>downlink" + c, _ := stats.GetOrRegisterCounter(statsManager, name) + if c != nil { + downlinkCounter = c + } + } + + if len(conf.Peers) == 0 { + return nil, errors.New("empty peers") + } + for _, peer := range conf.Peers { + if peer.PublicKey == "" { + return nil, errors.New("peer without publickey") + } + } + + localAddresses := make([]netip.Addr, 0, len(conf.Endpoint)) + for _, localaddress := range conf.Endpoint { + addr, err := netip.ParseAddr(localaddress) + if err == nil { + localAddresses = append(localAddresses, addr) + continue + } + prefix, err := netip.ParsePrefix(localaddress) + if err == nil { + localAddresses = append(localAddresses, prefix.Addr()) + continue + } + return nil, err + } + + tun, _, stack, err := CreateNetTUN(localAddresses, nil, int(conf.Mtu), false) if err != nil { return nil, err } - server := &Server{ - bindServer: &netBindServer{ - netBind: netBind{ - dns: v.GetFeature(dns.ClientType()).(dns.Client), - dnsOption: dns.IPOption{ - IPv4Enable: hasIPv4, - IPv6Enable: hasIPv6, - }, - workers: int(conf.NumWorkers), - readQueue: make(chan *netReadInfo), - }, - }, - policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager), - } + return &Server{ + conf: conf, + ctx: core.ToBackgroundDetachedContext(ctx), + policyManager: p, + dispatcher: d, - tun, err := conf.createTun()(endpoints, int(conf.Mtu), server.forwardConnection) - if err != nil { - return nil, err - } + tag: inbound.Tag, + src: inbound.Source, + sniffingRequest: content.SniffingRequest, + streamSettings: streamSettings, + uplinkCounter: uplinkCounter, + downlinkCounter: downlinkCounter, - if err = tun.BuildDevice(createIPCRequest(conf), server.bindServer); err != nil { - _ = tun.Close() - return nil, err - } - - return server, nil + tun: tun, + stack: stack, + }, nil } -// Network implements proxy.Inbound. +// Network implements proxy.Inbound.Network. func (*Server) Network() []net.Network { - return []net.Network{net.Network_UDP} + return []net.Network{} } -// Process implements proxy.Inbound. +// Process implements proxy.Inbound.Process. func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error { - s.info = routingInfo{ - ctx: ctx, - dispatcher: dispatcher, - inboundTag: session.InboundFromContext(ctx), - contentTag: session.ContentFromContext(ctx), - } + return nil +} - ep, err := s.bindServer.ParseEndpoint(conn.RemoteAddr().String()) +// Close implements common.Closable.Close. +func (s *Server) Close() error { + s.mu.Lock() + defer s.mu.Unlock() + if s.dev != nil { + s.dev.Close() + s.dev = nil + s.tun = nil + } else if s.tun != nil { + s.tun.Close() + s.tun = nil + } + return nil +} + +// Start implements common.Runnable.Start. +func (s *Server) Start() error { + s.mu.Lock() + defer s.mu.Unlock() + if s.dev != nil { + return nil + } + if s.src.Address.Family().IsDomain() { + return errors.New("address is domain") + } + listenFunc := func() (net.PacketConn, error) { + pktConn, err := internet.ListenSystemPacket(context.Background(), &net.UDPAddr{IP: s.src.Address.IP(), Port: int(s.src.Port)}, s.streamSettings.SocketSettings) + if err != nil { + return nil, err + } + if s.streamSettings.UdpmaskManager != nil { + newConn, err := s.streamSettings.UdpmaskManager.WrapPacketConnServer(pktConn) + if err != nil { + pktConn.Close() + return nil, errors.New("mask err").Base(err) + } + pktConn = newConn + } + if s.uplinkCounter != nil || s.downlinkCounter != nil { + pktConn = &PacketCounterConnection{ + PacketConn: pktConn, + ReadCounter: s.uplinkCounter, + WriteCounter: s.downlinkCounter, + } + } + return pktConn, nil + } + bind := &bind{ + listenFunc: listenFunc, + } + logger := &device.Logger{ + Verbosef: func(format string, args ...any) { + log.Record(&log.GeneralMessage{ + Severity: log.Severity_Debug, + Content: fmt.Sprintf(format, args...), + }) + }, + Errorf: func(format string, args ...any) { + log.Record(&log.GeneralMessage{ + Severity: log.Severity_Error, + Content: fmt.Sprintf(format, args...), + }) + }, + } + dev := device.NewDevice(s.tun, bind, logger) + var cfg strings.Builder + cfg.WriteString("private_key=" + s.conf.SecretKey + "\n") + for _, peer := range s.conf.Peers { + cfg.WriteString("public_key=" + peer.PublicKey + "\n") + if peer.PreSharedKey != "" { + cfg.WriteString("preshared_key=" + peer.PreSharedKey + "\n") + } + for _, ip := range peer.AllowedIps { + cfg.WriteString("allowed_ip=" + ip + "\n") + } + if peer.KeepAlive != "" { + cfg.WriteString("persistent_keepalive_interval=" + peer.KeepAlive + "\n") + } + } + err := dev.IpcSet(cfg.String()) if err != nil { return err } - - nep := ep.(*netEndpoint) - nep.conn = conn - - reader := buf.NewPacketReader(conn) - for { - mb, err := reader.ReadMultiBuffer() - if err != nil { - nep.conn = nil - buf.ReleaseMulti(mb) - return err - } - - for i, b := range mb { - - rawBytes := b.Bytes() - if b.Len() > 3 { - rawBytes[1] = 0 - rawBytes[2] = 0 - rawBytes[3] = 0 - } - - select { - case s.bindServer.readQueue <- &netReadInfo{ - buff: b, - endpoint: nep, - }: - case <-s.bindServer.closedCh: - nep.conn = nil - buf.ReleaseMulti(mb[i:]) - return errors.New("bind closed") - } - } + err = dev.Up() + if err != nil { + return err } + s.dev = dev + createForwarder(s.stack, s.HandleConnection) + return nil } -func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) { - if s.info.dispatcher == nil { - errors.LogError(s.info.ctx, "unexpected: dispatcher == nil") - return +func (s *Server) HandleConnection(conn net.Conn, dest net.Destination) { + defer conn.Close() + ctx, cancel := context.WithCancel(s.ctx) + defer cancel() + ctx = c.ContextWithID(ctx, session.NewID()) + + source := net.DestinationFromAddr(conn.RemoteAddr()) + inbound := session.Inbound{ + Name: "wireguard", + Tag: s.tag, + CanSpliceCopy: 3, + Source: source, } - ctx, cancel := context.WithCancel(core.ToBackgroundDetachedContext(s.info.ctx)) - sid := session.NewID() - ctx = c.ContextWithID(ctx, sid) - inbound := session.Inbound{} // since promiscuousModeHandler mixed-up context, we shallow copy inbound (tag) and content (configs) - if s.info.inboundTag != nil { - inbound = *s.info.inboundTag - } - inbound.Name = "wireguard" - inbound.CanSpliceCopy = 3 - - // overwrite the source to use the tun address for each sub context. - // Since gvisor.ForwarderRequest doesn't provide any info to associate the sub-context with the Parent context - // Currently we have no way to link to the original source address - inbound.Source = net.DestinationFromAddr(conn.RemoteAddr()) ctx = session.ContextWithInbound(ctx, &inbound) - content := new(session.Content) - if s.info.contentTag != nil { - content.SniffingRequest = s.info.contentTag.SniffingRequest - } - ctx = session.ContextWithContent(ctx, content) + ctx = session.ContextWithContent(ctx, &session.Content{ + SniffingRequest: s.sniffingRequest, + }) ctx = session.SubContextFromMuxInbound(ctx) ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{ - From: nullDestination, + From: inbound.Source, To: dest, Status: log.AccessAccepted, Reason: "", }) + errors.LogInfo(ctx, "processing from ", source, " to ", dest) - err := s.info.dispatcher.DispatchLink(ctx, dest, &transport.Link{ - Reader: buf.NewReader(conn), + link := &transport.Link{ + Reader: &buf.TimeoutWrapperReader{Reader: buf.NewReader(conn)}, Writer: buf.NewWriter(conn), - }) - if err != nil { - errors.LogInfoInner(ctx, err, "connection ends") } - - cancel() - conn.Close() + if err := s.dispatcher.DispatchLink(ctx, dest, link); err != nil { + errors.LogError(ctx, errors.New("connection closed").Base(err)) + } } diff --git a/proxy/wireguard/server_test.go b/proxy/wireguard/server_test.go deleted file mode 100644 index 1cb4697ce..000000000 --- a/proxy/wireguard/server_test.go +++ /dev/null @@ -1,53 +0,0 @@ -package wireguard_test - -import ( - "context" - "runtime/debug" - "testing" - - "github.com/stretchr/testify/assert" - - "github.com/xtls/xray-core/core" - "github.com/xtls/xray-core/proxy/wireguard" -) - -// TestWireGuardServerInitializationError verifies that an error during TUN initialization -// (triggered by an empty SecretKey) in the WireGuard server does not cause a panic and returns an error instead. -func TestWireGuardServerInitializationError(t *testing.T) { - // Create a minimal core instance with default features - config := &core.Config{} - instance, err := core.New(config) - if err != nil { - t.Fatalf("Failed to create core instance: %v", err) - } - // Set the Xray instance in the context - ctx := context.WithValue(context.Background(), core.XrayKey(1), instance) - - // Define the server configuration with an empty SecretKey to trigger error - conf := &wireguard.DeviceConfig{ - IsClient: false, - Endpoint: []string{"10.0.0.1/32"}, - Mtu: 1420, - SecretKey: "", // Empty SecretKey to trigger error - Peers: []*wireguard.PeerConfig{ - { - PublicKey: "some_public_key", - AllowedIps: []string{"10.0.0.2/32"}, - }, - }, - } - - // Use defer to catch any panic and fail the test explicitly - defer func() { - if r := recover(); r != nil { - t.Errorf("TUN initialization panicked: %v", r) - debug.PrintStack() - } - }() - - // Attempt to initialize the WireGuard server - _, err = wireguard.NewServer(ctx, conf) - - // Check that an error is returned - assert.ErrorContains(t, err, "failed to set private_key: hex string does not fit the slice") -} diff --git a/proxy/wireguard/tun.go b/proxy/wireguard/tun.go index 971ce89e0..68bad3ac6 100644 --- a/proxy/wireguard/tun.go +++ b/proxy/wireguard/tun.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "io" - "net/netip" "runtime" "strconv" "strings" @@ -13,9 +12,7 @@ import ( "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/errors" - "github.com/xtls/xray-core/common/log" "github.com/xtls/xray-core/common/net" - "github.com/xtls/xray-core/proxy/wireguard/gvisortun" "gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" @@ -25,77 +22,8 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" - - "golang.zx2c4.com/wireguard/conn" - "golang.zx2c4.com/wireguard/device" - "golang.zx2c4.com/wireguard/tun" ) -type tunCreator func(localAddresses []netip.Addr, mtu int, handler promiscuousModeHandler) (Tunnel, error) - -type promiscuousModeHandler func(dest net.Destination, conn net.Conn) - -type Tunnel interface { - BuildDevice(ipc string, bind conn.Bind) error - DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (net.Conn, error) - DialUDPAddrPort(laddr, raddr netip.AddrPort) (net.Conn, error) - Close() error -} - -type tunnel struct { - tun tun.Device - device *device.Device - rw sync.Mutex -} - -func (t *tunnel) BuildDevice(ipc string, bind conn.Bind) (err error) { - t.rw.Lock() - defer t.rw.Unlock() - - if t.device != nil { - return errors.New("device is already initialized") - } - - logger := &device.Logger{ - Verbosef: func(format string, args ...any) { - log.Record(&log.GeneralMessage{ - Severity: log.Severity_Debug, - Content: fmt.Sprintf(format, args...), - }) - }, - Errorf: func(format string, args ...any) { - log.Record(&log.GeneralMessage{ - Severity: log.Severity_Error, - Content: fmt.Sprintf(format, args...), - }) - }, - } - - t.device = device.NewDevice(t.tun, bind, logger) - if err = t.device.IpcSet(ipc); err != nil { - return err - } - if err = t.device.Up(); err != nil { - return err - } - return nil -} - -func (t *tunnel) Close() (err error) { - t.rw.Lock() - defer t.rw.Unlock() - - if t.device == nil { - return nil - } - - t.device.Close() - t.device = nil - err = t.tun.Close() - t.tun = nil - return nil -} - func CalculateInterfaceName(name string) (tunName string) { if runtime.GOOS == "darwin" { tunName = "utun" @@ -121,93 +49,61 @@ func CalculateInterfaceName(name string) (tunName string) { return } -var _ Tunnel = (*gvisorNet)(nil) +func createForwarder(gstack *stack.Stack, handler func(conn net.Conn, dest net.Destination)) { + gstack.SetPromiscuousMode(1, true) + gstack.SetSpoofing(1, true) -type gvisorNet struct { - tunnel - net *gvisortun.Net -} + tcpForwarder := tcp.NewForwarder(gstack, 0, 65535, func(r *tcp.ForwarderRequest) { + go func(r *tcp.ForwarderRequest) { + var wq waiter.Queue + id := r.ID() -func (g *gvisorNet) Close() error { - return g.tunnel.Close() -} - -func (g *gvisorNet) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) ( - net.Conn, error, -) { - return g.net.DialContextTCPAddrPort(ctx, addr) -} - -func (g *gvisorNet) DialUDPAddrPort(laddr, raddr netip.AddrPort) (net.Conn, error) { - return g.net.DialUDPAddrPort(laddr, raddr) -} - -func createGVisorTun(localAddresses []netip.Addr, mtu int, handler promiscuousModeHandler) (Tunnel, error) { - out := &gvisorNet{} - tun, n, gstack, err := gvisortun.CreateNetTUN(localAddresses, mtu, handler != nil) - if err != nil { - return nil, err - } - - if handler != nil { - // handler is only used for promiscuous mode - // capture all packets and send to handler - - tcpForwarder := tcp.NewForwarder(gstack, 0, 65535, func(r *tcp.ForwarderRequest) { - go func(r *tcp.ForwarderRequest) { - var wq waiter.Queue - id := r.ID() - - ep, err := r.CreateEndpoint(&wq) - if err != nil { - errors.LogError(context.Background(), err.String()) - r.Complete(true) - return - } - - options := ep.SocketOptions() - options.SetKeepAlive(false) - options.SetReuseAddress(true) - options.SetReusePort(true) - - handler(net.TCPDestination(net.IPAddress(id.LocalAddress.AsSlice()), net.Port(id.LocalPort)), gonet.NewTCPConn(&wq, ep)) - - ep.Close() - r.Complete(false) - }(r) - }) - gstack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket) - - manager := &udpManager{ - stack: gstack, - handler: handler, - m: make(map[string]*udpConn), - } - - gstack.SetTransportProtocolHandler(udp.ProtocolNumber, func(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { - data := pkt.Clone().Data().AsRange().ToSlice() - // if len(data) == 0 { - // return false - // } - srcIP := net.IPAddress(id.RemoteAddress.AsSlice()) - dstIP := net.IPAddress(id.LocalAddress.AsSlice()) - if srcIP == nil || dstIP == nil { - panic(id) + ep, err := r.CreateEndpoint(&wq) + if err != nil { + errors.LogError(context.Background(), err.String()) + r.Complete(true) + return } - src := net.UDPDestination(srcIP, net.Port(id.RemotePort)) - dst := net.UDPDestination(dstIP, net.Port(id.LocalPort)) - manager.feed(src, dst, data) - return true - }) + + options := ep.SocketOptions() + options.SetKeepAlive(false) + options.SetReuseAddress(true) + options.SetReusePort(true) + + handler(gonet.NewTCPConn(&wq, ep), net.TCPDestination(net.IPAddress(id.LocalAddress.AsSlice()), net.Port(id.LocalPort))) + + ep.Close() + r.Complete(false) + }(r) + }) + gstack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket) + + manager := &udpManager{ + stack: gstack, + handler: handler, + m: make(map[string]*udpConn), } - out.tun, out.net = tun, n - return out, nil + gstack.SetTransportProtocolHandler(udp.ProtocolNumber, func(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { + data := pkt.Clone().Data().AsRange().ToSlice() + // if len(data) == 0 { + // return false + // } + srcIP := net.IPAddress(id.RemoteAddress.AsSlice()) + dstIP := net.IPAddress(id.LocalAddress.AsSlice()) + if srcIP == nil || dstIP == nil { + panic(id) + } + src := net.UDPDestination(srcIP, net.Port(id.RemotePort)) + dst := net.UDPDestination(dstIP, net.Port(id.LocalPort)) + manager.feed(src, dst, data) + return true + }) } type udpManager struct { stack *stack.Stack - handler func(dest net.Destination, conn net.Conn) + handler func(conn net.Conn, dest net.Destination) m map[string]*udpConn mutex sync.RWMutex } @@ -246,7 +142,7 @@ func (m *udpManager) feed(src net.Destination, dst net.Destination, data []byte) m.mutex.Unlock() } m.m[src.NetAddr()] = uc - go m.handler(dst, uc) + go m.handler(uc, dst) } select { diff --git a/proxy/wireguard/tun_default.go b/proxy/wireguard/tun_default.go index 50a509444..edad5545d 100644 --- a/proxy/wireguard/tun_default.go +++ b/proxy/wireguard/tun_default.go @@ -1,14 +1,16 @@ -//go:build !linux || android +//go:build !linux package wireguard import ( "errors" "net/netip" + + "golang.zx2c4.com/wireguard/tun" ) -func createKernelTun(localAddresses []netip.Addr, mtu int, handler promiscuousModeHandler) (t Tunnel, err error) { - return nil, errors.New("not implemented") +func createKernelTun([]netip.Addr, []netip.Addr, int) (tdev tun.Device, tnet *Net, err error) { + return nil, nil, errors.New("not implemented") } func KernelTunSupported() (bool, error) { diff --git a/proxy/wireguard/tun_linux.go b/proxy/wireguard/tun_linux.go index b8a742e11..eb0175476 100644 --- a/proxy/wireguard/tun_linux.go +++ b/proxy/wireguard/tun_linux.go @@ -1,4 +1,4 @@ -//go:build linux && !android +//go:build linux package wireguard @@ -20,17 +20,6 @@ import ( "golang.zx2c4.com/wireguard/tun" ) -type deviceNet struct { - tunnel - dialer *net.Dialer - lc *net.ListenConfig - - handle *netlink.Handle - linkAddrs []netlink.Addr - routes []*netlink.Route - rules []*netlink.Rule -} - var ( tableIndex int = 10230 mu sync.Mutex @@ -48,82 +37,18 @@ func allocateIPv6TableIndex() int { return currentIndex } -func newDeviceNet(interfaceName string) *deviceNet { - dialer := &net.Dialer{} - dialer.Control = func(network, address string, c syscall.RawConn) error { - return c.Control(func(fd uintptr) { - if err := syscall.BindToDevice(int(fd), interfaceName); err != nil { - errors.LogInfoInner(context.Background(), err, "failed to bind to device") - } - }) - } - lc := &net.ListenConfig{} - lc.Control = func(network, address string, c syscall.RawConn) error { - return c.Control(func(fd uintptr) { - if err := syscall.BindToDevice(int(fd), interfaceName); err != nil { - errors.LogInfoInner(context.Background(), err, "failed to bind to device") - } - }) - } - return &deviceNet{dialer: dialer, lc: lc} +type kernelTun struct { + tun.Device + + dialer *net.Dialer + lc *net.ListenConfig + handle *netlink.Handle + linkAddrs []netlink.Addr + routes []*netlink.Route + rules []*netlink.Rule } -func (d *deviceNet) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) ( - net.Conn, error, -) { - return d.dialer.DialContext(ctx, "tcp", addr.String()) -} - -func (d *deviceNet) DialUDPAddrPort(laddr, raddr netip.AddrPort) (net.Conn, error) { - var conn net.PacketConn - var err error - if raddr.Addr().Is4() { - conn, err = d.lc.ListenPacket(context.Background(), "udp", "0.0.0.0:0") - } else { - conn, err = d.lc.ListenPacket(context.Background(), "udp", "[::]:0") - } - if err != nil { - return nil, err - } - return &internet.PacketConnWrapper{ - PacketConn: conn, - Dest: &net.UDPAddr{ - IP: raddr.Addr().AsSlice(), - Port: int(raddr.Port()), - }, - }, nil -} - -func (d *deviceNet) Close() (err error) { - var errs []error - for _, rule := range d.rules { - if err = d.handle.RuleDel(rule); err != nil { - errs = append(errs, fmt.Errorf("failed to delete rule: %w", err)) - } - } - for _, route := range d.routes { - if err = d.handle.RouteDel(route); err != nil { - errs = append(errs, fmt.Errorf("failed to delete route: %w", err)) - } - } - if err = d.tunnel.Close(); err != nil { - errs = append(errs, fmt.Errorf("failed to close tunnel: %w", err)) - } - if d.handle != nil { - d.handle.Close() - d.handle = nil - } - if len(errs) == 0 { - return nil - } - return goerrors.Join(errs...) -} - -func createKernelTun(localAddresses []netip.Addr, mtu int, handler promiscuousModeHandler) (t Tunnel, err error) { - if handler != nil { - return nil, errors.New("TODO: support promiscuous mode") - } - +func createKernelTun(localAddresses, dnsServers []netip.Addr, mtu int) (tdev tun.Device, tnet *Net, err error) { var v4, v6 *netip.Addr for _, prefixes := range localAddresses { if v4 == nil && prefixes.Is4() { @@ -150,22 +75,22 @@ func createKernelTun(localAddresses []netip.Addr, mtu int, handler promiscuousMo // system configs. if v4 != nil { if err = writeSysctlZero("/proc/sys/net/ipv4/conf/all/rp_filter"); err != nil { - return nil, fmt.Errorf("failed to disable ipv4 rp_filter for all: %w", err) + return nil, nil, fmt.Errorf("failed to disable ipv4 rp_filter for all: %w", err) } } if v6 != nil { if err = writeSysctlZero("/proc/sys/net/ipv6/conf/all/disable_ipv6"); err != nil { - return nil, fmt.Errorf("failed to enable ipv6: %w", err) + return nil, nil, fmt.Errorf("failed to enable ipv6: %w", err) } if err = writeSysctlZero("/proc/sys/net/ipv6/conf/all/rp_filter"); err != nil { - return nil, fmt.Errorf("failed to disable ipv6 rp_filter for all: %w", err) + return nil, nil, fmt.Errorf("failed to disable ipv6 rp_filter for all: %w", err) } } n := CalculateInterfaceName("wg") wgt, err := tun.CreateTUN(n, mtu) if err != nil { - return nil, err + return nil, nil, err } defer func() { if err != nil { @@ -177,12 +102,12 @@ func createKernelTun(localAddresses []netip.Addr, mtu int, handler promiscuousMo // the operation require root privilege on container require '--privileged' flag. if v4 != nil { if err = writeSysctlZero("/proc/sys/net/ipv4/conf/" + n + "/rp_filter"); err != nil { - return nil, fmt.Errorf("failed to disable ipv4 rp_filter for tunnel: %w", err) + return nil, nil, fmt.Errorf("failed to disable ipv4 rp_filter for tunnel: %w", err) } } if v6 != nil { if err = writeSysctlZero("/proc/sys/net/ipv6/conf/" + n + "/rp_filter"); err != nil { - return nil, fmt.Errorf("failed to disable ipv6 rp_filter for tunnel: %w", err) + return nil, nil, fmt.Errorf("failed to disable ipv6 rp_filter for tunnel: %w", err) } } @@ -196,25 +121,28 @@ func createKernelTun(localAddresses []netip.Addr, mtu int, handler promiscuousMo } ipv6TableIndex-- if ipv6TableIndex < 0 { - return nil, fmt.Errorf("failed to find available ipv6 table index") + return nil, nil, fmt.Errorf("failed to find available ipv6 table index") } } } - out := newDeviceNet(n) - out.handle, err = netlink.NewHandle() + t := &kernelTun{ + Device: wgt, + } + + t.handle, err = netlink.NewHandle() if err != nil { - return nil, err + return nil, nil, err } defer func() { if err != nil { - _ = out.Close() + t.Close() } }() l, err := netlink.LinkByName(n) if err != nil { - return nil, err + return nil, nil, err } if v4 != nil { @@ -224,7 +152,7 @@ func createKernelTun(localAddresses []netip.Addr, mtu int, handler promiscuousMo Mask: net.CIDRMask(v4.BitLen(), v4.BitLen()), }, } - out.linkAddrs = append(out.linkAddrs, addr) + t.linkAddrs = append(t.linkAddrs, addr) } if v6 != nil { addr := netlink.Addr{ @@ -233,7 +161,7 @@ func createKernelTun(localAddresses []netip.Addr, mtu int, handler promiscuousMo Mask: net.CIDRMask(v6.BitLen(), v6.BitLen()), }, } - out.linkAddrs = append(out.linkAddrs, addr) + t.linkAddrs = append(t.linkAddrs, addr) rt := &netlink.Route{ LinkIndex: l.Attrs().Index, @@ -243,40 +171,102 @@ func createKernelTun(localAddresses []netip.Addr, mtu int, handler promiscuousMo }, Table: ipv6TableIndex, } - out.routes = append(out.routes, rt) + t.routes = append(t.routes, rt) r := netlink.NewRule() r.Table, r.Family, r.Src = ipv6TableIndex, unix.AF_INET6, addr.IPNet - out.rules = append(out.rules, r) + t.rules = append(t.rules, r) r = netlink.NewRule() r.Table, r.Family, r.OifName = ipv6TableIndex, unix.AF_INET6, n - out.rules = append(out.rules, r) + t.rules = append(t.rules, r) } - for _, addr := range out.linkAddrs { - if err = out.handle.AddrAdd(l, &addr); err != nil { - return nil, fmt.Errorf("failed to add address %s to %s: %w", addr, n, err) + for _, addr := range t.linkAddrs { + if err = t.handle.AddrAdd(l, &addr); err != nil { + return nil, nil, fmt.Errorf("failed to add address %s to %s: %w", addr, n, err) } } - if err = out.handle.LinkSetMTU(l, mtu); err != nil { - return nil, err + if err = t.handle.LinkSetMTU(l, mtu); err != nil { + return nil, nil, err } - if err = out.handle.LinkSetUp(l); err != nil { - return nil, err + if err = t.handle.LinkSetUp(l); err != nil { + return nil, nil, err } - for _, route := range out.routes { - if err = out.handle.RouteAdd(route); err != nil { - return nil, fmt.Errorf("failed to add route %s: %w", route, err) + for _, route := range t.routes { + if err = t.handle.RouteAdd(route); err != nil { + return nil, nil, fmt.Errorf("failed to add route %s: %w", route, err) } } - for _, rule := range out.rules { - if err = out.handle.RuleAdd(rule); err != nil { - return nil, fmt.Errorf("failed to add rule %s: %w", rule, err) + for _, rule := range t.rules { + if err = t.handle.RuleAdd(rule); err != nil { + return nil, nil, fmt.Errorf("failed to add rule %s: %w", rule, err) } } - out.tun = wgt - return out, nil + + dialer := &net.Dialer{} + dialer.Control = func(network, address string, c syscall.RawConn) error { + return c.Control(func(fd uintptr) { + if err := syscall.BindToDevice(int(fd), n); err != nil { + errors.LogInfoInner(context.Background(), err, "failed to bind to device") + } + }) + } + lc := &net.ListenConfig{} + lc.Control = func(network, address string, c syscall.RawConn) error { + return c.Control(func(fd uintptr) { + if err := syscall.BindToDevice(int(fd), n); err != nil { + errors.LogInfoInner(context.Background(), err, "failed to bind to device") + } + }) + } + t.dialer = dialer + t.lc = lc + + tnet = &Net{ + DialContextTCPAddrPort: t.DialContextTCPAddrPort, + DialUDPAddrPort: t.DialUDPAddrPort, + dnsServers: dnsServers, + hasV4: v4 != nil, + hasV6: v6 != nil, + } + + return t, tnet, nil +} + +func (tun *kernelTun) Close() (err error) { + var errs []error + for _, rule := range tun.rules { + if err = tun.handle.RuleDel(rule); err != nil { + errs = append(errs, fmt.Errorf("failed to delete rule: %w", err)) + } + } + for _, route := range tun.routes { + if err = tun.handle.RouteDel(route); err != nil { + errs = append(errs, fmt.Errorf("failed to delete route: %w", err)) + } + } + if err = tun.Device.Close(); err != nil { + errs = append(errs, fmt.Errorf("failed to close device: %w", err)) + } + tun.handle.Close() + errs = append(errs, tun.Device.Close()) + return goerrors.Join(errs...) +} + +func (tun *kernelTun) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (net.Conn, error) { + return tun.dialer.DialContext(ctx, "tcp", addr.String()) +} + +func (tun *kernelTun) DialUDPAddrPort(laddr, raddr netip.AddrPort) (net.Conn, error) { + conn, err := tun.lc.ListenPacket(context.Background(), "udp", ":0") + if err != nil { + return nil, err + } + return &internet.PacketConnWrapper{ + PacketConn: conn, + Dest: net.UDPAddrFromAddrPort(raddr), + }, nil } func KernelTunSupported() (bool, error) { diff --git a/proxy/wireguard/wireguard.go b/proxy/wireguard/wireguard.go index 4f489114f..7e9dca029 100644 --- a/proxy/wireguard/wireguard.go +++ b/proxy/wireguard/wireguard.go @@ -2,10 +2,6 @@ package wireguard import ( "context" - "errors" - "fmt" - "net/netip" - "strings" "github.com/xtls/xray-core/common" ) @@ -14,80 +10,9 @@ func init() { common.Must(common.RegisterConfig((*DeviceConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { deviceConfig := config.(*DeviceConfig) if deviceConfig.IsClient { - return New(ctx, deviceConfig) + return NewClient(ctx, deviceConfig) } else { return NewServer(ctx, deviceConfig) } })) } - -// convert endpoint string to netip.Addr -func parseEndpoints(conf *DeviceConfig) ([]netip.Addr, bool, bool, error) { - var hasIPv4, hasIPv6 bool - - endpoints := make([]netip.Addr, len(conf.Endpoint)) - for i, str := range conf.Endpoint { - var addr netip.Addr - if strings.Contains(str, "/") { - prefix, err := netip.ParsePrefix(str) - if err != nil { - return nil, false, false, err - } - addr = prefix.Addr() - if prefix.Bits() != addr.BitLen() { - return nil, false, false, errors.New("interface address subnet should be /32 for IPv4 and /128 for IPv6") - } - } else { - var err error - addr, err = netip.ParseAddr(str) - if err != nil { - return nil, false, false, err - } - } - endpoints[i] = addr - - if addr.Is4() { - hasIPv4 = true - } else if addr.Is6() { - hasIPv6 = true - } - } - - return endpoints, hasIPv4, hasIPv6, nil -} - -// serialize the config into an IPC request -func createIPCRequest(conf *DeviceConfig) string { - var request strings.Builder - - request.WriteString(fmt.Sprintf("private_key=%s\n", conf.SecretKey)) - - if !conf.IsClient { - // placeholder, we'll handle actual port listening on Xray - request.WriteString("listen_port=1337\n") - } - - for _, peer := range conf.Peers { - if peer.PublicKey != "" { - request.WriteString(fmt.Sprintf("public_key=%s\n", peer.PublicKey)) - } - - if peer.PreSharedKey != "" { - request.WriteString(fmt.Sprintf("preshared_key=%s\n", peer.PreSharedKey)) - } - - if peer.Endpoint != "" { - request.WriteString(fmt.Sprintf("endpoint=%s\n", peer.Endpoint)) - } - - for _, ip := range peer.AllowedIps { - request.WriteString(fmt.Sprintf("allowed_ip=%s\n", ip)) - } - - if peer.KeepAlive != 0 { - request.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", peer.KeepAlive)) - } - } - - return request.String()[:request.Len()] -}