mirror of
https://github.com/XTLS/Xray-core.git
synced 2026-07-03 18:28:52 +00:00
WireGuard proxy: Refactor (#6287)
And https://github.com/XTLS/Xray-core/pull/6303#issuecomment-4669158076 Fixes https://github.com/XTLS/Xray-core/issues/6257
This commit is contained in:
+108
-118
@@ -1,4 +1,4 @@
|
||||
//go:build linux && !android
|
||||
//go:build linux
|
||||
|
||||
package wireguard
|
||||
|
||||
@@ -20,17 +20,6 @@ import (
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
)
|
||||
|
||||
type deviceNet struct {
|
||||
tunnel
|
||||
dialer *net.Dialer
|
||||
lc *net.ListenConfig
|
||||
|
||||
handle *netlink.Handle
|
||||
linkAddrs []netlink.Addr
|
||||
routes []*netlink.Route
|
||||
rules []*netlink.Rule
|
||||
}
|
||||
|
||||
var (
|
||||
tableIndex int = 10230
|
||||
mu sync.Mutex
|
||||
@@ -48,82 +37,18 @@ func allocateIPv6TableIndex() int {
|
||||
return currentIndex
|
||||
}
|
||||
|
||||
func newDeviceNet(interfaceName string) *deviceNet {
|
||||
dialer := &net.Dialer{}
|
||||
dialer.Control = func(network, address string, c syscall.RawConn) error {
|
||||
return c.Control(func(fd uintptr) {
|
||||
if err := syscall.BindToDevice(int(fd), interfaceName); err != nil {
|
||||
errors.LogInfoInner(context.Background(), err, "failed to bind to device")
|
||||
}
|
||||
})
|
||||
}
|
||||
lc := &net.ListenConfig{}
|
||||
lc.Control = func(network, address string, c syscall.RawConn) error {
|
||||
return c.Control(func(fd uintptr) {
|
||||
if err := syscall.BindToDevice(int(fd), interfaceName); err != nil {
|
||||
errors.LogInfoInner(context.Background(), err, "failed to bind to device")
|
||||
}
|
||||
})
|
||||
}
|
||||
return &deviceNet{dialer: dialer, lc: lc}
|
||||
type kernelTun struct {
|
||||
tun.Device
|
||||
|
||||
dialer *net.Dialer
|
||||
lc *net.ListenConfig
|
||||
handle *netlink.Handle
|
||||
linkAddrs []netlink.Addr
|
||||
routes []*netlink.Route
|
||||
rules []*netlink.Rule
|
||||
}
|
||||
|
||||
func (d *deviceNet) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (
|
||||
net.Conn, error,
|
||||
) {
|
||||
return d.dialer.DialContext(ctx, "tcp", addr.String())
|
||||
}
|
||||
|
||||
func (d *deviceNet) DialUDPAddrPort(laddr, raddr netip.AddrPort) (net.Conn, error) {
|
||||
var conn net.PacketConn
|
||||
var err error
|
||||
if raddr.Addr().Is4() {
|
||||
conn, err = d.lc.ListenPacket(context.Background(), "udp", "0.0.0.0:0")
|
||||
} else {
|
||||
conn, err = d.lc.ListenPacket(context.Background(), "udp", "[::]:0")
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &internet.PacketConnWrapper{
|
||||
PacketConn: conn,
|
||||
Dest: &net.UDPAddr{
|
||||
IP: raddr.Addr().AsSlice(),
|
||||
Port: int(raddr.Port()),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (d *deviceNet) Close() (err error) {
|
||||
var errs []error
|
||||
for _, rule := range d.rules {
|
||||
if err = d.handle.RuleDel(rule); err != nil {
|
||||
errs = append(errs, fmt.Errorf("failed to delete rule: %w", err))
|
||||
}
|
||||
}
|
||||
for _, route := range d.routes {
|
||||
if err = d.handle.RouteDel(route); err != nil {
|
||||
errs = append(errs, fmt.Errorf("failed to delete route: %w", err))
|
||||
}
|
||||
}
|
||||
if err = d.tunnel.Close(); err != nil {
|
||||
errs = append(errs, fmt.Errorf("failed to close tunnel: %w", err))
|
||||
}
|
||||
if d.handle != nil {
|
||||
d.handle.Close()
|
||||
d.handle = nil
|
||||
}
|
||||
if len(errs) == 0 {
|
||||
return nil
|
||||
}
|
||||
return goerrors.Join(errs...)
|
||||
}
|
||||
|
||||
func createKernelTun(localAddresses []netip.Addr, mtu int, handler promiscuousModeHandler) (t Tunnel, err error) {
|
||||
if handler != nil {
|
||||
return nil, errors.New("TODO: support promiscuous mode")
|
||||
}
|
||||
|
||||
func createKernelTun(localAddresses, dnsServers []netip.Addr, mtu int) (tdev tun.Device, tnet *Net, err error) {
|
||||
var v4, v6 *netip.Addr
|
||||
for _, prefixes := range localAddresses {
|
||||
if v4 == nil && prefixes.Is4() {
|
||||
@@ -150,22 +75,22 @@ func createKernelTun(localAddresses []netip.Addr, mtu int, handler promiscuousMo
|
||||
// system configs.
|
||||
if v4 != nil {
|
||||
if err = writeSysctlZero("/proc/sys/net/ipv4/conf/all/rp_filter"); err != nil {
|
||||
return nil, fmt.Errorf("failed to disable ipv4 rp_filter for all: %w", err)
|
||||
return nil, nil, fmt.Errorf("failed to disable ipv4 rp_filter for all: %w", err)
|
||||
}
|
||||
}
|
||||
if v6 != nil {
|
||||
if err = writeSysctlZero("/proc/sys/net/ipv6/conf/all/disable_ipv6"); err != nil {
|
||||
return nil, fmt.Errorf("failed to enable ipv6: %w", err)
|
||||
return nil, nil, fmt.Errorf("failed to enable ipv6: %w", err)
|
||||
}
|
||||
if err = writeSysctlZero("/proc/sys/net/ipv6/conf/all/rp_filter"); err != nil {
|
||||
return nil, fmt.Errorf("failed to disable ipv6 rp_filter for all: %w", err)
|
||||
return nil, nil, fmt.Errorf("failed to disable ipv6 rp_filter for all: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
n := CalculateInterfaceName("wg")
|
||||
wgt, err := tun.CreateTUN(n, mtu)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
@@ -177,12 +102,12 @@ func createKernelTun(localAddresses []netip.Addr, mtu int, handler promiscuousMo
|
||||
// the operation require root privilege on container require '--privileged' flag.
|
||||
if v4 != nil {
|
||||
if err = writeSysctlZero("/proc/sys/net/ipv4/conf/" + n + "/rp_filter"); err != nil {
|
||||
return nil, fmt.Errorf("failed to disable ipv4 rp_filter for tunnel: %w", err)
|
||||
return nil, nil, fmt.Errorf("failed to disable ipv4 rp_filter for tunnel: %w", err)
|
||||
}
|
||||
}
|
||||
if v6 != nil {
|
||||
if err = writeSysctlZero("/proc/sys/net/ipv6/conf/" + n + "/rp_filter"); err != nil {
|
||||
return nil, fmt.Errorf("failed to disable ipv6 rp_filter for tunnel: %w", err)
|
||||
return nil, nil, fmt.Errorf("failed to disable ipv6 rp_filter for tunnel: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -196,25 +121,28 @@ func createKernelTun(localAddresses []netip.Addr, mtu int, handler promiscuousMo
|
||||
}
|
||||
ipv6TableIndex--
|
||||
if ipv6TableIndex < 0 {
|
||||
return nil, fmt.Errorf("failed to find available ipv6 table index")
|
||||
return nil, nil, fmt.Errorf("failed to find available ipv6 table index")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
out := newDeviceNet(n)
|
||||
out.handle, err = netlink.NewHandle()
|
||||
t := &kernelTun{
|
||||
Device: wgt,
|
||||
}
|
||||
|
||||
t.handle, err = netlink.NewHandle()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
_ = out.Close()
|
||||
t.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
l, err := netlink.LinkByName(n)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if v4 != nil {
|
||||
@@ -224,7 +152,7 @@ func createKernelTun(localAddresses []netip.Addr, mtu int, handler promiscuousMo
|
||||
Mask: net.CIDRMask(v4.BitLen(), v4.BitLen()),
|
||||
},
|
||||
}
|
||||
out.linkAddrs = append(out.linkAddrs, addr)
|
||||
t.linkAddrs = append(t.linkAddrs, addr)
|
||||
}
|
||||
if v6 != nil {
|
||||
addr := netlink.Addr{
|
||||
@@ -233,7 +161,7 @@ func createKernelTun(localAddresses []netip.Addr, mtu int, handler promiscuousMo
|
||||
Mask: net.CIDRMask(v6.BitLen(), v6.BitLen()),
|
||||
},
|
||||
}
|
||||
out.linkAddrs = append(out.linkAddrs, addr)
|
||||
t.linkAddrs = append(t.linkAddrs, addr)
|
||||
|
||||
rt := &netlink.Route{
|
||||
LinkIndex: l.Attrs().Index,
|
||||
@@ -243,40 +171,102 @@ func createKernelTun(localAddresses []netip.Addr, mtu int, handler promiscuousMo
|
||||
},
|
||||
Table: ipv6TableIndex,
|
||||
}
|
||||
out.routes = append(out.routes, rt)
|
||||
t.routes = append(t.routes, rt)
|
||||
|
||||
r := netlink.NewRule()
|
||||
r.Table, r.Family, r.Src = ipv6TableIndex, unix.AF_INET6, addr.IPNet
|
||||
out.rules = append(out.rules, r)
|
||||
t.rules = append(t.rules, r)
|
||||
r = netlink.NewRule()
|
||||
r.Table, r.Family, r.OifName = ipv6TableIndex, unix.AF_INET6, n
|
||||
out.rules = append(out.rules, r)
|
||||
t.rules = append(t.rules, r)
|
||||
}
|
||||
|
||||
for _, addr := range out.linkAddrs {
|
||||
if err = out.handle.AddrAdd(l, &addr); err != nil {
|
||||
return nil, fmt.Errorf("failed to add address %s to %s: %w", addr, n, err)
|
||||
for _, addr := range t.linkAddrs {
|
||||
if err = t.handle.AddrAdd(l, &addr); err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to add address %s to %s: %w", addr, n, err)
|
||||
}
|
||||
}
|
||||
if err = out.handle.LinkSetMTU(l, mtu); err != nil {
|
||||
return nil, err
|
||||
if err = t.handle.LinkSetMTU(l, mtu); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if err = out.handle.LinkSetUp(l); err != nil {
|
||||
return nil, err
|
||||
if err = t.handle.LinkSetUp(l); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
for _, route := range out.routes {
|
||||
if err = out.handle.RouteAdd(route); err != nil {
|
||||
return nil, fmt.Errorf("failed to add route %s: %w", route, err)
|
||||
for _, route := range t.routes {
|
||||
if err = t.handle.RouteAdd(route); err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to add route %s: %w", route, err)
|
||||
}
|
||||
}
|
||||
for _, rule := range out.rules {
|
||||
if err = out.handle.RuleAdd(rule); err != nil {
|
||||
return nil, fmt.Errorf("failed to add rule %s: %w", rule, err)
|
||||
for _, rule := range t.rules {
|
||||
if err = t.handle.RuleAdd(rule); err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to add rule %s: %w", rule, err)
|
||||
}
|
||||
}
|
||||
out.tun = wgt
|
||||
return out, nil
|
||||
|
||||
dialer := &net.Dialer{}
|
||||
dialer.Control = func(network, address string, c syscall.RawConn) error {
|
||||
return c.Control(func(fd uintptr) {
|
||||
if err := syscall.BindToDevice(int(fd), n); err != nil {
|
||||
errors.LogInfoInner(context.Background(), err, "failed to bind to device")
|
||||
}
|
||||
})
|
||||
}
|
||||
lc := &net.ListenConfig{}
|
||||
lc.Control = func(network, address string, c syscall.RawConn) error {
|
||||
return c.Control(func(fd uintptr) {
|
||||
if err := syscall.BindToDevice(int(fd), n); err != nil {
|
||||
errors.LogInfoInner(context.Background(), err, "failed to bind to device")
|
||||
}
|
||||
})
|
||||
}
|
||||
t.dialer = dialer
|
||||
t.lc = lc
|
||||
|
||||
tnet = &Net{
|
||||
DialContextTCPAddrPort: t.DialContextTCPAddrPort,
|
||||
DialUDPAddrPort: t.DialUDPAddrPort,
|
||||
dnsServers: dnsServers,
|
||||
hasV4: v4 != nil,
|
||||
hasV6: v6 != nil,
|
||||
}
|
||||
|
||||
return t, tnet, nil
|
||||
}
|
||||
|
||||
func (tun *kernelTun) Close() (err error) {
|
||||
var errs []error
|
||||
for _, rule := range tun.rules {
|
||||
if err = tun.handle.RuleDel(rule); err != nil {
|
||||
errs = append(errs, fmt.Errorf("failed to delete rule: %w", err))
|
||||
}
|
||||
}
|
||||
for _, route := range tun.routes {
|
||||
if err = tun.handle.RouteDel(route); err != nil {
|
||||
errs = append(errs, fmt.Errorf("failed to delete route: %w", err))
|
||||
}
|
||||
}
|
||||
if err = tun.Device.Close(); err != nil {
|
||||
errs = append(errs, fmt.Errorf("failed to close device: %w", err))
|
||||
}
|
||||
tun.handle.Close()
|
||||
errs = append(errs, tun.Device.Close())
|
||||
return goerrors.Join(errs...)
|
||||
}
|
||||
|
||||
func (tun *kernelTun) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (net.Conn, error) {
|
||||
return tun.dialer.DialContext(ctx, "tcp", addr.String())
|
||||
}
|
||||
|
||||
func (tun *kernelTun) DialUDPAddrPort(laddr, raddr netip.AddrPort) (net.Conn, error) {
|
||||
conn, err := tun.lc.ListenPacket(context.Background(), "udp", ":0")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &internet.PacketConnWrapper{
|
||||
PacketConn: conn,
|
||||
Dest: net.UDPAddrFromAddrPort(raddr),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func KernelTunSupported() (bool, error) {
|
||||
|
||||
Reference in New Issue
Block a user