mirror of
https://github.com/XTLS/Xray-core.git
synced 2026-07-04 10:48:49 +00:00
WireGuard: Implement UDP FullCone NAT (#5833)
Fixes https://github.com/XTLS/Xray-core/issues/5601 --------- Co-authored-by: RPRX <63339210+RPRX@users.noreply.github.com>
This commit is contained in:
+218
-36
@@ -3,6 +3,7 @@ package wireguard
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"strconv"
|
||||
@@ -10,12 +11,17 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/xtls/xray-core/common/buf"
|
||||
"github.com/xtls/xray-core/common/errors"
|
||||
"github.com/xtls/xray-core/common/log"
|
||||
"github.com/xtls/xray-core/common/net"
|
||||
"github.com/xtls/xray-core/proxy/wireguard/gvisortun"
|
||||
"gvisor.dev/gvisor/pkg/buffer"
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/checksum"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||
"gvisor.dev/gvisor/pkg/waiter"
|
||||
@@ -138,7 +144,7 @@ func (g *gvisorNet) DialUDPAddrPort(laddr, raddr netip.AddrPort) (net.Conn, erro
|
||||
|
||||
func createGVisorTun(localAddresses []netip.Addr, mtu int, handler promiscuousModeHandler) (Tunnel, error) {
|
||||
out := &gvisorNet{}
|
||||
tun, n, stack, err := gvisortun.CreateNetTUN(localAddresses, mtu, handler != nil)
|
||||
tun, n, gstack, err := gvisortun.CreateNetTUN(localAddresses, mtu, handler != nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -147,60 +153,236 @@ func createGVisorTun(localAddresses []netip.Addr, mtu int, handler promiscuousMo
|
||||
// handler is only used for promiscuous mode
|
||||
// capture all packets and send to handler
|
||||
|
||||
tcpForwarder := tcp.NewForwarder(stack, 0, 65535, func(r *tcp.ForwarderRequest) {
|
||||
tcpForwarder := tcp.NewForwarder(gstack, 0, 65535, func(r *tcp.ForwarderRequest) {
|
||||
go func(r *tcp.ForwarderRequest) {
|
||||
var (
|
||||
wq waiter.Queue
|
||||
id = r.ID()
|
||||
)
|
||||
var wq waiter.Queue
|
||||
var id = r.ID()
|
||||
|
||||
// Perform a TCP three-way handshake.
|
||||
ep, err := r.CreateEndpoint(&wq)
|
||||
if err != nil {
|
||||
errors.LogError(context.Background(), err.String())
|
||||
r.Complete(true)
|
||||
return
|
||||
}
|
||||
r.Complete(false)
|
||||
defer ep.Close()
|
||||
|
||||
// enable tcp keep-alive to prevent hanging connections
|
||||
ep.SocketOptions().SetKeepAlive(true)
|
||||
options := ep.SocketOptions()
|
||||
options.SetKeepAlive(false)
|
||||
options.SetReuseAddress(true)
|
||||
options.SetReusePort(true)
|
||||
|
||||
// local address is actually destination
|
||||
handler(net.TCPDestination(net.IPAddress(id.LocalAddress.AsSlice()), net.Port(id.LocalPort)), gonet.NewTCPConn(&wq, ep))
|
||||
|
||||
ep.Close()
|
||||
r.Complete(false)
|
||||
}(r)
|
||||
})
|
||||
stack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
|
||||
gstack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
|
||||
|
||||
udpForwarder := udp.NewForwarder(stack, func(r *udp.ForwarderRequest) bool {
|
||||
go func(r *udp.ForwarderRequest) {
|
||||
var (
|
||||
wq waiter.Queue
|
||||
id = r.ID()
|
||||
)
|
||||
|
||||
ep, err := r.CreateEndpoint(&wq)
|
||||
if err != nil {
|
||||
errors.LogError(context.Background(), err.String())
|
||||
return
|
||||
}
|
||||
defer ep.Close()
|
||||
|
||||
// prevents hanging connections and ensure timely release
|
||||
ep.SocketOptions().SetLinger(tcpip.LingerOption{
|
||||
Enabled: true,
|
||||
Timeout: 15 * time.Second,
|
||||
})
|
||||
|
||||
handler(net.UDPDestination(net.IPAddress(id.LocalAddress.AsSlice()), net.Port(id.LocalPort)), gonet.NewUDPConn(&wq, ep))
|
||||
}(r)
|
||||
manager := &udpManager{
|
||||
stack: gstack,
|
||||
handler: handler,
|
||||
m: make(map[string]*udpConn),
|
||||
}
|
||||
|
||||
gstack.SetTransportProtocolHandler(udp.ProtocolNumber, func(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
|
||||
data := pkt.Clone().Data().AsRange().ToSlice()
|
||||
// if len(data) == 0 {
|
||||
// return false
|
||||
// }
|
||||
src := net.UDPDestination(net.IPAddress(id.RemoteAddress.AsSlice()), net.Port(id.RemotePort))
|
||||
dst := net.UDPDestination(net.IPAddress(id.LocalAddress.AsSlice()), net.Port(id.LocalPort))
|
||||
manager.feed(src, dst, data)
|
||||
return true
|
||||
})
|
||||
stack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
|
||||
}
|
||||
|
||||
out.tun, out.net = tun, n
|
||||
return out, nil
|
||||
}
|
||||
|
||||
type udpManager struct {
|
||||
stack *stack.Stack
|
||||
handler func(dest net.Destination, conn net.Conn)
|
||||
m map[string]*udpConn
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
func (m *udpManager) feed(src net.Destination, dst net.Destination, data []byte) {
|
||||
m.mutex.RLock()
|
||||
uc, ok := m.m[src.NetAddr()]
|
||||
if ok {
|
||||
select {
|
||||
case uc.ch <- data:
|
||||
default:
|
||||
}
|
||||
m.mutex.RUnlock()
|
||||
return
|
||||
}
|
||||
m.mutex.RUnlock()
|
||||
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
uc, ok = m.m[src.NetAddr()]
|
||||
if !ok {
|
||||
uc = &udpConn{
|
||||
ch: make(chan []byte, 1024),
|
||||
src: src,
|
||||
dst: dst,
|
||||
}
|
||||
uc.writeFunc = m.writeRawUDPPacket
|
||||
uc.closeFunc = func() {
|
||||
m.mutex.Lock()
|
||||
m.close(uc)
|
||||
m.mutex.Unlock()
|
||||
}
|
||||
m.m[src.NetAddr()] = uc
|
||||
go m.handler(dst, uc)
|
||||
}
|
||||
|
||||
select {
|
||||
case uc.ch <- data:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (m *udpManager) close(uc *udpConn) {
|
||||
if !uc.closed {
|
||||
uc.closed = true
|
||||
close(uc.ch)
|
||||
delete(m.m, uc.src.NetAddr())
|
||||
}
|
||||
}
|
||||
|
||||
func (m *udpManager) writeRawUDPPacket(payload []byte, src net.Destination, dst net.Destination) error {
|
||||
udpLen := header.UDPMinimumSize + len(payload)
|
||||
srcIP := tcpip.AddrFromSlice(src.Address.IP())
|
||||
dstIP := tcpip.AddrFromSlice(dst.Address.IP())
|
||||
|
||||
// build packet with appropriate IP header size
|
||||
isIPv4 := dst.Address.Family().IsIPv4()
|
||||
ipHdrSize := header.IPv6MinimumSize
|
||||
ipProtocol := header.IPv6ProtocolNumber
|
||||
if isIPv4 {
|
||||
ipHdrSize = header.IPv4MinimumSize
|
||||
ipProtocol = header.IPv4ProtocolNumber
|
||||
}
|
||||
|
||||
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||
ReserveHeaderBytes: ipHdrSize + header.UDPMinimumSize,
|
||||
Payload: buffer.MakeWithData(payload),
|
||||
})
|
||||
defer pkt.DecRef()
|
||||
|
||||
// Build UDP header
|
||||
udpHdr := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
|
||||
udpHdr.Encode(&header.UDPFields{
|
||||
SrcPort: uint16(src.Port),
|
||||
DstPort: uint16(dst.Port),
|
||||
Length: uint16(udpLen),
|
||||
})
|
||||
|
||||
// Calculate and set UDP checksum
|
||||
xsum := header.PseudoHeaderChecksum(header.UDPProtocolNumber, srcIP, dstIP, uint16(udpLen))
|
||||
udpHdr.SetChecksum(^udpHdr.CalculateChecksum(checksum.Checksum(payload, xsum)))
|
||||
|
||||
// Build IP header
|
||||
if isIPv4 {
|
||||
ipHdr := header.IPv4(pkt.NetworkHeader().Push(header.IPv4MinimumSize))
|
||||
ipHdr.Encode(&header.IPv4Fields{
|
||||
TotalLength: uint16(header.IPv4MinimumSize + udpLen),
|
||||
TTL: 64,
|
||||
Protocol: uint8(header.UDPProtocolNumber),
|
||||
SrcAddr: srcIP,
|
||||
DstAddr: dstIP,
|
||||
})
|
||||
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
|
||||
} else {
|
||||
ipHdr := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
|
||||
ipHdr.Encode(&header.IPv6Fields{
|
||||
PayloadLength: uint16(udpLen),
|
||||
TransportProtocol: header.UDPProtocolNumber,
|
||||
HopLimit: 64,
|
||||
SrcAddr: srcIP,
|
||||
DstAddr: dstIP,
|
||||
})
|
||||
}
|
||||
|
||||
// dispatch the packet
|
||||
err := m.stack.WriteRawPacket(1, ipProtocol, buffer.MakeWithView(pkt.ToView()))
|
||||
if err != nil {
|
||||
return errors.New("failed to write raw udp packet back to stack err ", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type udpConn struct {
|
||||
ch chan []byte
|
||||
src net.Destination
|
||||
dst net.Destination
|
||||
writeFunc func(payload []byte, src net.Destination, dst net.Destination) error
|
||||
closeFunc func()
|
||||
closed bool
|
||||
}
|
||||
|
||||
func (c *udpConn) Read(p []byte) (int, error) {
|
||||
b, ok := <-c.ch
|
||||
if !ok {
|
||||
return 0, io.EOF
|
||||
}
|
||||
n := copy(p, b)
|
||||
if n != len(b) {
|
||||
return 0, io.ErrShortBuffer
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (c *udpConn) WriteMultiBuffer(mb buf.MultiBuffer) error {
|
||||
for i, b := range mb {
|
||||
dst := c.dst
|
||||
if b.UDP != nil {
|
||||
dst = *b.UDP
|
||||
}
|
||||
err := c.writeFunc(b.Bytes(), dst, c.src)
|
||||
if err != nil {
|
||||
buf.ReleaseMulti(mb[i:])
|
||||
return err
|
||||
}
|
||||
b.Release()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *udpConn) Write(p []byte) (int, error) {
|
||||
err := c.writeFunc(p, c.dst, c.src)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (c *udpConn) Close() error {
|
||||
c.closeFunc()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *udpConn) LocalAddr() net.Addr {
|
||||
return c.src.RawNetAddr() // fake
|
||||
}
|
||||
|
||||
func (c *udpConn) RemoteAddr() net.Addr {
|
||||
return c.src.RawNetAddr() // src
|
||||
}
|
||||
|
||||
func (c *udpConn) SetDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *udpConn) SetReadDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *udpConn) SetWriteDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user