mirror of
https://github.com/XTLS/Xray-core.git
synced 2026-05-14 18:09:05 +00:00
WireGuard: Implement UDP FullCone NAT (#5833)
Fixes https://github.com/XTLS/Xray-core/issues/5601 --------- Co-authored-by: RPRX <63339210+RPRX@users.noreply.github.com>
This commit is contained in:
+1
-1
@@ -53,7 +53,7 @@ func GetGlobalID(ctx context.Context) (globalID [8]byte) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Source.Network == net.Network_UDP &&
|
if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Source.Network == net.Network_UDP &&
|
||||||
(inbound.Name == "dokodemo-door" || inbound.Name == "socks" || inbound.Name == "shadowsocks" || inbound.Name == "tun") {
|
(inbound.Name == "dokodemo-door" || inbound.Name == "socks" || inbound.Name == "shadowsocks" || inbound.Name == "tun" || inbound.Name == "wireguard") {
|
||||||
h := blake3.New(8, BaseKey)
|
h := blake3.New(8, BaseKey)
|
||||||
h.Write([]byte(inbound.Source.String()))
|
h.Write([]byte(inbound.Source.String()))
|
||||||
copy(globalID[:], h.Sum(nil))
|
copy(globalID[:], h.Sum(nil))
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ require (
|
|||||||
golang.org/x/sync v0.20.0
|
golang.org/x/sync v0.20.0
|
||||||
golang.org/x/sys v0.42.0
|
golang.org/x/sys v0.42.0
|
||||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
|
||||||
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173
|
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb
|
||||||
google.golang.org/grpc v1.79.3
|
google.golang.org/grpc v1.79.3
|
||||||
google.golang.org/protobuf v1.36.11
|
google.golang.org/protobuf v1.36.11
|
||||||
gvisor.dev/gvisor v0.0.0-20260122175437-89a5d21be8f0
|
gvisor.dev/gvisor v0.0.0-20260122175437-89a5d21be8f0
|
||||||
|
|||||||
@@ -131,6 +131,8 @@ golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeu
|
|||||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
||||||
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 h1:/jFs0duh4rdb8uIfPMv78iAJGcPKDeqAFnaLBropIC4=
|
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 h1:/jFs0duh4rdb8uIfPMv78iAJGcPKDeqAFnaLBropIC4=
|
||||||
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
|
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
|
||||||
|
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb h1:whnFRlWMcXI9d+ZbWg+4sHnLp52d5yiIPUxMBSt4X9A=
|
||||||
|
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
|
||||||
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
|
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
|
||||||
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
|
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
|
||||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217 h1:gRkg/vSppuSQoDjxyiGfN4Upv/h/DQmIR10ZU8dh4Ww=
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217 h1:gRkg/vSppuSQoDjxyiGfN4Upv/h/DQmIR10ZU8dh4Ww=
|
||||||
|
|||||||
@@ -130,7 +130,7 @@ func ParseWireGuardKey(str string) (string, error) {
|
|||||||
return "", errors.New("key must not be empty")
|
return "", errors.New("key must not be empty")
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(str)%2 == 0 {
|
if len(str) == 64 {
|
||||||
_, err = hex.DecodeString(str)
|
_, err = hex.DecodeString(str)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return str, nil
|
return str, nil
|
||||||
|
|||||||
@@ -227,6 +227,11 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
|
|||||||
}
|
}
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
|
conn = &udpConnClient{
|
||||||
|
Conn: conn,
|
||||||
|
dest: destination,
|
||||||
|
}
|
||||||
|
|
||||||
requestFunc = func() error {
|
requestFunc = func() error {
|
||||||
defer timer.SetTimeout(p.Timeouts.DownlinkOnly)
|
defer timer.SetTimeout(p.Timeouts.DownlinkOnly)
|
||||||
return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer))
|
return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer))
|
||||||
@@ -336,3 +341,34 @@ func (h *Handler) createIPCRequest() string {
|
|||||||
|
|
||||||
return request.String()[:request.Len()]
|
return request.String()[:request.Len()]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type udpConnClient struct {
|
||||||
|
net.Conn
|
||||||
|
dest net.Destination
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *udpConnClient) ReadMultiBuffer() (buf.MultiBuffer, error) {
|
||||||
|
b := buf.New()
|
||||||
|
b.Resize(0, buf.Size)
|
||||||
|
n, addr, err := c.Conn.(net.PacketConn).ReadFrom(b.Bytes())
|
||||||
|
if err != nil {
|
||||||
|
b.Release()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if addr == nil { // should never hit
|
||||||
|
addr = c.dest.RawNetAddr()
|
||||||
|
}
|
||||||
|
b.Resize(0, int32(n))
|
||||||
|
|
||||||
|
b.UDP = &net.Destination{
|
||||||
|
Address: net.IPAddress(addr.(*net.UDPAddr).IP),
|
||||||
|
Port: net.Port(addr.(*net.UDPAddr).Port),
|
||||||
|
Network: net.Network_UDP,
|
||||||
|
}
|
||||||
|
|
||||||
|
return buf.MultiBuffer{b}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *udpConnClient) Write(p []byte) (int, error) {
|
||||||
|
return c.Conn.(net.PacketConn).WriteTo(p, c.dest.RawNetAddr())
|
||||||
|
}
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ type netTun struct {
|
|||||||
ep *channel.Endpoint
|
ep *channel.Endpoint
|
||||||
stack *stack.Stack
|
stack *stack.Stack
|
||||||
events chan tun.Event
|
events chan tun.Event
|
||||||
|
notifyHandle *channel.NotificationHandle
|
||||||
incomingPacket chan *buffer.View
|
incomingPacket chan *buffer.View
|
||||||
mtu int
|
mtu int
|
||||||
hasV4, hasV6 bool
|
hasV4, hasV6 bool
|
||||||
@@ -48,12 +49,17 @@ func CreateNetTUN(localAddresses []netip.Addr, mtu int, promiscuousMode bool) (t
|
|||||||
dev := &netTun{
|
dev := &netTun{
|
||||||
ep: channel.New(1024, uint32(mtu), ""),
|
ep: channel.New(1024, uint32(mtu), ""),
|
||||||
stack: stack.New(opts),
|
stack: stack.New(opts),
|
||||||
events: make(chan tun.Event, 1),
|
events: make(chan tun.Event, 10),
|
||||||
incomingPacket: make(chan *buffer.View),
|
incomingPacket: make(chan *buffer.View),
|
||||||
mtu: mtu,
|
mtu: mtu,
|
||||||
}
|
}
|
||||||
dev.ep.AddNotify(dev)
|
sackEnabledOpt := tcpip.TCPSACKEnabled(true) // TCP SACK is disabled by default
|
||||||
tcpipErr := dev.stack.CreateNIC(1, dev.ep)
|
tcpipErr := dev.stack.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt)
|
||||||
|
if tcpipErr != nil {
|
||||||
|
return nil, nil, dev.stack, fmt.Errorf("could not enable TCP SACK: %v", tcpipErr)
|
||||||
|
}
|
||||||
|
dev.notifyHandle = dev.ep.AddNotify(dev)
|
||||||
|
tcpipErr = dev.stack.CreateNIC(1, dev.ep)
|
||||||
if tcpipErr != nil {
|
if tcpipErr != nil {
|
||||||
return nil, nil, dev.stack, fmt.Errorf("CreateNIC: %v", tcpipErr)
|
return nil, nil, dev.stack, fmt.Errorf("CreateNIC: %v", tcpipErr)
|
||||||
}
|
}
|
||||||
@@ -90,20 +96,10 @@ func CreateNetTUN(localAddresses []netip.Addr, mtu int, promiscuousMode bool) (t
|
|||||||
dev.stack.SetSpoofing(1, true)
|
dev.stack.SetSpoofing(1, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
opt := tcpip.CongestionControlOption("cubic")
|
|
||||||
if err := dev.stack.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
|
|
||||||
return nil, nil, dev.stack, fmt.Errorf("SetTransportProtocolOption(%d, &%T(%s)): %s", tcp.ProtocolNumber, opt, opt, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
dev.events <- tun.EventUp
|
dev.events <- tun.EventUp
|
||||||
return dev, (*Net)(dev), dev.stack, nil
|
return dev, (*Net)(dev), dev.stack, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// BatchSize implements tun.Device
|
|
||||||
func (tun *netTun) BatchSize() int {
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
// Name implements tun.Device
|
// Name implements tun.Device
|
||||||
func (tun *netTun) Name() (string, error) {
|
func (tun *netTun) Name() (string, error) {
|
||||||
return "go", nil
|
return "go", nil
|
||||||
@@ -120,7 +116,6 @@ func (tun *netTun) Events() <-chan tun.Event {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Read implements tun.Device
|
// Read implements tun.Device
|
||||||
|
|
||||||
func (tun *netTun) Read(buf [][]byte, sizes []int, offset int) (int, error) {
|
func (tun *netTun) Read(buf [][]byte, sizes []int, offset int) (int, error) {
|
||||||
view, ok := <-tun.incomingPacket
|
view, ok := <-tun.incomingPacket
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -169,20 +164,16 @@ func (tun *netTun) WriteNotify() {
|
|||||||
tun.incomingPacket <- view
|
tun.incomingPacket <- view
|
||||||
}
|
}
|
||||||
|
|
||||||
// Flush implements tun.Device
|
|
||||||
func (tun *netTun) Flush() error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close implements tun.Device
|
// Close implements tun.Device
|
||||||
func (tun *netTun) Close() error {
|
func (tun *netTun) Close() error {
|
||||||
tun.closeOnce.Do(func() {
|
tun.closeOnce.Do(func() {
|
||||||
tun.stack.RemoveNIC(1)
|
tun.stack.RemoveNIC(1)
|
||||||
|
tun.stack.Close()
|
||||||
|
tun.ep.RemoveNotify(tun.notifyHandle)
|
||||||
|
tun.ep.Close()
|
||||||
|
|
||||||
close(tun.events)
|
close(tun.events)
|
||||||
|
|
||||||
tun.ep.Close()
|
|
||||||
|
|
||||||
close(tun.incomingPacket)
|
close(tun.incomingPacket)
|
||||||
})
|
})
|
||||||
return nil
|
return nil
|
||||||
@@ -193,6 +184,11 @@ func (tun *netTun) MTU() (int, error) {
|
|||||||
return tun.mtu, nil
|
return tun.mtu, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BatchSize implements tun.Device
|
||||||
|
func (tun *netTun) BatchSize() int {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) {
|
func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) {
|
||||||
var protoNumber tcpip.NetworkProtocolNumber
|
var protoNumber tcpip.NetworkProtocolNumber
|
||||||
if endpoint.Addr().Is4() {
|
if endpoint.Addr().Is4() {
|
||||||
@@ -224,6 +220,7 @@ func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, er
|
|||||||
var addr tcpip.FullAddress
|
var addr tcpip.FullAddress
|
||||||
addr, pn = convertToFullAddr(raddr)
|
addr, pn = convertToFullAddr(raddr)
|
||||||
rfa = &addr
|
rfa = &addr
|
||||||
|
rfa = nil // do not ep connect
|
||||||
}
|
}
|
||||||
return gonet.DialUDP(net.stack, lfa, rfa, pn)
|
return gonet.DialUDP(net.stack, lfa, rfa, pn)
|
||||||
}
|
}
|
||||||
|
|||||||
+13
-39
@@ -5,19 +5,17 @@ import (
|
|||||||
goerrors "errors"
|
goerrors "errors"
|
||||||
"io"
|
"io"
|
||||||
|
|
||||||
"github.com/xtls/xray-core/common"
|
|
||||||
"github.com/xtls/xray-core/common/buf"
|
"github.com/xtls/xray-core/common/buf"
|
||||||
c "github.com/xtls/xray-core/common/ctx"
|
c "github.com/xtls/xray-core/common/ctx"
|
||||||
"github.com/xtls/xray-core/common/errors"
|
"github.com/xtls/xray-core/common/errors"
|
||||||
"github.com/xtls/xray-core/common/log"
|
"github.com/xtls/xray-core/common/log"
|
||||||
"github.com/xtls/xray-core/common/net"
|
"github.com/xtls/xray-core/common/net"
|
||||||
"github.com/xtls/xray-core/common/session"
|
"github.com/xtls/xray-core/common/session"
|
||||||
"github.com/xtls/xray-core/common/signal"
|
|
||||||
"github.com/xtls/xray-core/common/task"
|
|
||||||
"github.com/xtls/xray-core/core"
|
"github.com/xtls/xray-core/core"
|
||||||
"github.com/xtls/xray-core/features/dns"
|
"github.com/xtls/xray-core/features/dns"
|
||||||
"github.com/xtls/xray-core/features/policy"
|
"github.com/xtls/xray-core/features/policy"
|
||||||
"github.com/xtls/xray-core/features/routing"
|
"github.com/xtls/xray-core/features/routing"
|
||||||
|
"github.com/xtls/xray-core/transport"
|
||||||
"github.com/xtls/xray-core/transport/internet/stat"
|
"github.com/xtls/xray-core/transport/internet/stat"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -31,10 +29,10 @@ type Server struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type routingInfo struct {
|
type routingInfo struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
dispatcher routing.Dispatcher
|
dispatcher routing.Dispatcher
|
||||||
inboundTag *session.Inbound
|
inboundTag *session.Inbound
|
||||||
contentTag *session.Content
|
contentTag *session.Content
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewServer(ctx context.Context, conf *DeviceConfig) (*Server, error) {
|
func NewServer(ctx context.Context, conf *DeviceConfig) (*Server, error) {
|
||||||
@@ -124,7 +122,6 @@ func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) {
|
|||||||
errors.LogError(s.info.ctx, "unexpected: dispatcher == nil")
|
errors.LogError(s.info.ctx, "unexpected: dispatcher == nil")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(core.ToBackgroundDetachedContext(s.info.ctx))
|
ctx, cancel := context.WithCancel(core.ToBackgroundDetachedContext(s.info.ctx))
|
||||||
sid := session.NewID()
|
sid := session.NewID()
|
||||||
@@ -146,9 +143,6 @@ func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) {
|
|||||||
}
|
}
|
||||||
ctx = session.SubContextFromMuxInbound(ctx)
|
ctx = session.SubContextFromMuxInbound(ctx)
|
||||||
|
|
||||||
plcy := s.policyManager.ForLevel(0)
|
|
||||||
timer := signal.CancelAfterInactivity(ctx, cancel, plcy.Timeouts.ConnectionIdle)
|
|
||||||
|
|
||||||
ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
|
ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
|
||||||
From: nullDestination,
|
From: nullDestination,
|
||||||
To: dest,
|
To: dest,
|
||||||
@@ -156,35 +150,15 @@ func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) {
|
|||||||
Reason: "",
|
Reason: "",
|
||||||
})
|
})
|
||||||
|
|
||||||
link, err := s.info.dispatcher.Dispatch(ctx, dest)
|
err := s.info.dispatcher.DispatchLink(ctx, dest, &transport.Link{
|
||||||
|
Reader: buf.NewReader(conn),
|
||||||
|
Writer: buf.NewWriter(conn),
|
||||||
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errors.LogErrorInner(ctx, err, "dispatch connection")
|
errors.LogInfoInner(ctx, err, "connection ends")
|
||||||
}
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
requestDone := func() error {
|
|
||||||
defer timer.SetTimeout(plcy.Timeouts.DownlinkOnly)
|
|
||||||
if err := buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer)); err != nil {
|
|
||||||
return errors.New("failed to transport all TCP request").Base(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
responseDone := func() error {
|
cancel()
|
||||||
defer timer.SetTimeout(plcy.Timeouts.UplinkOnly)
|
conn.Close()
|
||||||
if err := buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer)); err != nil {
|
|
||||||
return errors.New("failed to transport all TCP response").Base(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
requestDonePost := task.OnSuccess(requestDone, task.Close(link.Writer))
|
|
||||||
if err := task.Run(ctx, requestDonePost, responseDone); err != nil {
|
|
||||||
common.Interrupt(link.Reader)
|
|
||||||
common.Interrupt(link.Writer)
|
|
||||||
errors.LogDebugInner(ctx, err, "connection ends")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
+218
-36
@@ -3,6 +3,7 @@ package wireguard
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -10,12 +11,17 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/xtls/xray-core/common/buf"
|
||||||
"github.com/xtls/xray-core/common/errors"
|
"github.com/xtls/xray-core/common/errors"
|
||||||
"github.com/xtls/xray-core/common/log"
|
"github.com/xtls/xray-core/common/log"
|
||||||
"github.com/xtls/xray-core/common/net"
|
"github.com/xtls/xray-core/common/net"
|
||||||
"github.com/xtls/xray-core/proxy/wireguard/gvisortun"
|
"github.com/xtls/xray-core/proxy/wireguard/gvisortun"
|
||||||
|
"gvisor.dev/gvisor/pkg/buffer"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip"
|
"gvisor.dev/gvisor/pkg/tcpip"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/checksum"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||||
"gvisor.dev/gvisor/pkg/waiter"
|
"gvisor.dev/gvisor/pkg/waiter"
|
||||||
@@ -138,7 +144,7 @@ func (g *gvisorNet) DialUDPAddrPort(laddr, raddr netip.AddrPort) (net.Conn, erro
|
|||||||
|
|
||||||
func createGVisorTun(localAddresses []netip.Addr, mtu int, handler promiscuousModeHandler) (Tunnel, error) {
|
func createGVisorTun(localAddresses []netip.Addr, mtu int, handler promiscuousModeHandler) (Tunnel, error) {
|
||||||
out := &gvisorNet{}
|
out := &gvisorNet{}
|
||||||
tun, n, stack, err := gvisortun.CreateNetTUN(localAddresses, mtu, handler != nil)
|
tun, n, gstack, err := gvisortun.CreateNetTUN(localAddresses, mtu, handler != nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -147,60 +153,236 @@ func createGVisorTun(localAddresses []netip.Addr, mtu int, handler promiscuousMo
|
|||||||
// handler is only used for promiscuous mode
|
// handler is only used for promiscuous mode
|
||||||
// capture all packets and send to handler
|
// capture all packets and send to handler
|
||||||
|
|
||||||
tcpForwarder := tcp.NewForwarder(stack, 0, 65535, func(r *tcp.ForwarderRequest) {
|
tcpForwarder := tcp.NewForwarder(gstack, 0, 65535, func(r *tcp.ForwarderRequest) {
|
||||||
go func(r *tcp.ForwarderRequest) {
|
go func(r *tcp.ForwarderRequest) {
|
||||||
var (
|
var wq waiter.Queue
|
||||||
wq waiter.Queue
|
var id = r.ID()
|
||||||
id = r.ID()
|
|
||||||
)
|
|
||||||
|
|
||||||
// Perform a TCP three-way handshake.
|
|
||||||
ep, err := r.CreateEndpoint(&wq)
|
ep, err := r.CreateEndpoint(&wq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errors.LogError(context.Background(), err.String())
|
errors.LogError(context.Background(), err.String())
|
||||||
r.Complete(true)
|
r.Complete(true)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
r.Complete(false)
|
|
||||||
defer ep.Close()
|
|
||||||
|
|
||||||
// enable tcp keep-alive to prevent hanging connections
|
options := ep.SocketOptions()
|
||||||
ep.SocketOptions().SetKeepAlive(true)
|
options.SetKeepAlive(false)
|
||||||
|
options.SetReuseAddress(true)
|
||||||
|
options.SetReusePort(true)
|
||||||
|
|
||||||
// local address is actually destination
|
|
||||||
handler(net.TCPDestination(net.IPAddress(id.LocalAddress.AsSlice()), net.Port(id.LocalPort)), gonet.NewTCPConn(&wq, ep))
|
handler(net.TCPDestination(net.IPAddress(id.LocalAddress.AsSlice()), net.Port(id.LocalPort)), gonet.NewTCPConn(&wq, ep))
|
||||||
|
|
||||||
|
ep.Close()
|
||||||
|
r.Complete(false)
|
||||||
}(r)
|
}(r)
|
||||||
})
|
})
|
||||||
stack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
|
gstack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
|
||||||
|
|
||||||
udpForwarder := udp.NewForwarder(stack, func(r *udp.ForwarderRequest) bool {
|
manager := &udpManager{
|
||||||
go func(r *udp.ForwarderRequest) {
|
stack: gstack,
|
||||||
var (
|
handler: handler,
|
||||||
wq waiter.Queue
|
m: make(map[string]*udpConn),
|
||||||
id = r.ID()
|
}
|
||||||
)
|
|
||||||
|
|
||||||
ep, err := r.CreateEndpoint(&wq)
|
|
||||||
if err != nil {
|
|
||||||
errors.LogError(context.Background(), err.String())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer ep.Close()
|
|
||||||
|
|
||||||
// prevents hanging connections and ensure timely release
|
|
||||||
ep.SocketOptions().SetLinger(tcpip.LingerOption{
|
|
||||||
Enabled: true,
|
|
||||||
Timeout: 15 * time.Second,
|
|
||||||
})
|
|
||||||
|
|
||||||
handler(net.UDPDestination(net.IPAddress(id.LocalAddress.AsSlice()), net.Port(id.LocalPort)), gonet.NewUDPConn(&wq, ep))
|
|
||||||
}(r)
|
|
||||||
|
|
||||||
|
gstack.SetTransportProtocolHandler(udp.ProtocolNumber, func(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
|
||||||
|
data := pkt.Clone().Data().AsRange().ToSlice()
|
||||||
|
// if len(data) == 0 {
|
||||||
|
// return false
|
||||||
|
// }
|
||||||
|
src := net.UDPDestination(net.IPAddress(id.RemoteAddress.AsSlice()), net.Port(id.RemotePort))
|
||||||
|
dst := net.UDPDestination(net.IPAddress(id.LocalAddress.AsSlice()), net.Port(id.LocalPort))
|
||||||
|
manager.feed(src, dst, data)
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
stack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
out.tun, out.net = tun, n
|
out.tun, out.net = tun, n
|
||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type udpManager struct {
|
||||||
|
stack *stack.Stack
|
||||||
|
handler func(dest net.Destination, conn net.Conn)
|
||||||
|
m map[string]*udpConn
|
||||||
|
mutex sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *udpManager) feed(src net.Destination, dst net.Destination, data []byte) {
|
||||||
|
m.mutex.RLock()
|
||||||
|
uc, ok := m.m[src.NetAddr()]
|
||||||
|
if ok {
|
||||||
|
select {
|
||||||
|
case uc.ch <- data:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
m.mutex.RUnlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
m.mutex.RUnlock()
|
||||||
|
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
uc, ok = m.m[src.NetAddr()]
|
||||||
|
if !ok {
|
||||||
|
uc = &udpConn{
|
||||||
|
ch: make(chan []byte, 1024),
|
||||||
|
src: src,
|
||||||
|
dst: dst,
|
||||||
|
}
|
||||||
|
uc.writeFunc = m.writeRawUDPPacket
|
||||||
|
uc.closeFunc = func() {
|
||||||
|
m.mutex.Lock()
|
||||||
|
m.close(uc)
|
||||||
|
m.mutex.Unlock()
|
||||||
|
}
|
||||||
|
m.m[src.NetAddr()] = uc
|
||||||
|
go m.handler(dst, uc)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case uc.ch <- data:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *udpManager) close(uc *udpConn) {
|
||||||
|
if !uc.closed {
|
||||||
|
uc.closed = true
|
||||||
|
close(uc.ch)
|
||||||
|
delete(m.m, uc.src.NetAddr())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *udpManager) writeRawUDPPacket(payload []byte, src net.Destination, dst net.Destination) error {
|
||||||
|
udpLen := header.UDPMinimumSize + len(payload)
|
||||||
|
srcIP := tcpip.AddrFromSlice(src.Address.IP())
|
||||||
|
dstIP := tcpip.AddrFromSlice(dst.Address.IP())
|
||||||
|
|
||||||
|
// build packet with appropriate IP header size
|
||||||
|
isIPv4 := dst.Address.Family().IsIPv4()
|
||||||
|
ipHdrSize := header.IPv6MinimumSize
|
||||||
|
ipProtocol := header.IPv6ProtocolNumber
|
||||||
|
if isIPv4 {
|
||||||
|
ipHdrSize = header.IPv4MinimumSize
|
||||||
|
ipProtocol = header.IPv4ProtocolNumber
|
||||||
|
}
|
||||||
|
|
||||||
|
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||||
|
ReserveHeaderBytes: ipHdrSize + header.UDPMinimumSize,
|
||||||
|
Payload: buffer.MakeWithData(payload),
|
||||||
|
})
|
||||||
|
defer pkt.DecRef()
|
||||||
|
|
||||||
|
// Build UDP header
|
||||||
|
udpHdr := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
|
||||||
|
udpHdr.Encode(&header.UDPFields{
|
||||||
|
SrcPort: uint16(src.Port),
|
||||||
|
DstPort: uint16(dst.Port),
|
||||||
|
Length: uint16(udpLen),
|
||||||
|
})
|
||||||
|
|
||||||
|
// Calculate and set UDP checksum
|
||||||
|
xsum := header.PseudoHeaderChecksum(header.UDPProtocolNumber, srcIP, dstIP, uint16(udpLen))
|
||||||
|
udpHdr.SetChecksum(^udpHdr.CalculateChecksum(checksum.Checksum(payload, xsum)))
|
||||||
|
|
||||||
|
// Build IP header
|
||||||
|
if isIPv4 {
|
||||||
|
ipHdr := header.IPv4(pkt.NetworkHeader().Push(header.IPv4MinimumSize))
|
||||||
|
ipHdr.Encode(&header.IPv4Fields{
|
||||||
|
TotalLength: uint16(header.IPv4MinimumSize + udpLen),
|
||||||
|
TTL: 64,
|
||||||
|
Protocol: uint8(header.UDPProtocolNumber),
|
||||||
|
SrcAddr: srcIP,
|
||||||
|
DstAddr: dstIP,
|
||||||
|
})
|
||||||
|
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
|
||||||
|
} else {
|
||||||
|
ipHdr := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
|
||||||
|
ipHdr.Encode(&header.IPv6Fields{
|
||||||
|
PayloadLength: uint16(udpLen),
|
||||||
|
TransportProtocol: header.UDPProtocolNumber,
|
||||||
|
HopLimit: 64,
|
||||||
|
SrcAddr: srcIP,
|
||||||
|
DstAddr: dstIP,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// dispatch the packet
|
||||||
|
err := m.stack.WriteRawPacket(1, ipProtocol, buffer.MakeWithView(pkt.ToView()))
|
||||||
|
if err != nil {
|
||||||
|
return errors.New("failed to write raw udp packet back to stack err ", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type udpConn struct {
|
||||||
|
ch chan []byte
|
||||||
|
src net.Destination
|
||||||
|
dst net.Destination
|
||||||
|
writeFunc func(payload []byte, src net.Destination, dst net.Destination) error
|
||||||
|
closeFunc func()
|
||||||
|
closed bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *udpConn) Read(p []byte) (int, error) {
|
||||||
|
b, ok := <-c.ch
|
||||||
|
if !ok {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
n := copy(p, b)
|
||||||
|
if n != len(b) {
|
||||||
|
return 0, io.ErrShortBuffer
|
||||||
|
}
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *udpConn) WriteMultiBuffer(mb buf.MultiBuffer) error {
|
||||||
|
for i, b := range mb {
|
||||||
|
dst := c.dst
|
||||||
|
if b.UDP != nil {
|
||||||
|
dst = *b.UDP
|
||||||
|
}
|
||||||
|
err := c.writeFunc(b.Bytes(), dst, c.src)
|
||||||
|
if err != nil {
|
||||||
|
buf.ReleaseMulti(mb[i:])
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
b.Release()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *udpConn) Write(p []byte) (int, error) {
|
||||||
|
err := c.writeFunc(p, c.dst, c.src)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *udpConn) Close() error {
|
||||||
|
c.closeFunc()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *udpConn) LocalAddr() net.Addr {
|
||||||
|
return c.src.RawNetAddr() // fake
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *udpConn) RemoteAddr() net.Addr {
|
||||||
|
return c.src.RawNetAddr() // src
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *udpConn) SetDeadline(t time.Time) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *udpConn) SetReadDeadline(t time.Time) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *udpConn) SetWriteDeadline(t time.Time) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -100,32 +100,39 @@ func (m *udpSessionManagerServer) run() {
|
|||||||
func (m *udpSessionManagerServer) feed(id uint32, d []byte) {
|
func (m *udpSessionManagerServer) feed(id uint32, d []byte) {
|
||||||
m.mutex.RLock()
|
m.mutex.RLock()
|
||||||
udpConn, ok := m.m[id]
|
udpConn, ok := m.m[id]
|
||||||
|
if ok {
|
||||||
|
select {
|
||||||
|
case udpConn.ch <- d:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
m.mutex.RUnlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
m.mutex.RUnlock()
|
m.mutex.RUnlock()
|
||||||
|
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
udpConn, ok = m.m[id]
|
||||||
if !ok {
|
if !ok {
|
||||||
m.mutex.Lock()
|
udpConn = &InterUdpConn{
|
||||||
udpConn, ok = m.m[id]
|
conn: m.conn,
|
||||||
if !ok {
|
local: m.conn.LocalAddr(),
|
||||||
udpConn = &InterUdpConn{
|
remote: m.conn.RemoteAddr(),
|
||||||
conn: m.conn,
|
|
||||||
local: m.conn.LocalAddr(),
|
|
||||||
remote: m.conn.RemoteAddr(),
|
|
||||||
|
|
||||||
id: id,
|
id: id,
|
||||||
ch: make(chan []byte, udpMessageChanSize),
|
ch: make(chan []byte, udpMessageChanSize),
|
||||||
last: time.Now(),
|
last: time.Now(),
|
||||||
|
|
||||||
user: m.user,
|
user: m.user,
|
||||||
}
|
|
||||||
udpConn.closeFunc = func() {
|
|
||||||
m.mutex.Lock()
|
|
||||||
defer m.mutex.Unlock()
|
|
||||||
m.close(udpConn)
|
|
||||||
}
|
|
||||||
m.m[id] = udpConn
|
|
||||||
m.addConn(udpConn)
|
|
||||||
}
|
}
|
||||||
m.mutex.Unlock()
|
udpConn.closeFunc = func() {
|
||||||
|
m.mutex.Lock()
|
||||||
|
m.close(udpConn)
|
||||||
|
m.mutex.Unlock()
|
||||||
|
}
|
||||||
|
m.m[id] = udpConn
|
||||||
|
m.addConn(udpConn)
|
||||||
}
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
|
|||||||
Reference in New Issue
Block a user