mirror of
https://github.com/XTLS/Xray-core.git
synced 2026-07-04 02:38:42 +00:00
WireGuard proxy: Refactor (#6287)
And https://github.com/XTLS/Xray-core/pull/6303#issuecomment-4669158076 Fixes https://github.com/XTLS/Xray-core/issues/6257
This commit is contained in:
+351
-250
@@ -1,148 +1,135 @@
|
||||
/*
|
||||
|
||||
Some of codes are copied from https://github.com/octeep/wireproxy, license below.
|
||||
|
||||
Copyright (c) 2022 Wind T.F. Wong <octeep@pm.me>
|
||||
|
||||
Permission to use, copy, modify, and distribute this software for any
|
||||
purpose with or without fee is hereby granted, provided that the above
|
||||
copyright notice and this permission notice appear in all copies.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||
WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||
ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||
OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
|
||||
*/
|
||||
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
gonet "net"
|
||||
"net/netip"
|
||||
reflect "reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
|
||||
"github.com/xtls/xray-core/common"
|
||||
"github.com/xtls/xray-core/common/buf"
|
||||
"github.com/xtls/xray-core/common/dice"
|
||||
"github.com/xtls/xray-core/common/errors"
|
||||
"github.com/xtls/xray-core/common/log"
|
||||
"github.com/xtls/xray-core/common/net"
|
||||
"github.com/xtls/xray-core/common/protocol"
|
||||
"github.com/xtls/xray-core/common/net/cnc"
|
||||
"github.com/xtls/xray-core/common/session"
|
||||
"github.com/xtls/xray-core/common/signal"
|
||||
"github.com/xtls/xray-core/common/task"
|
||||
"github.com/xtls/xray-core/core"
|
||||
"github.com/xtls/xray-core/features/dns"
|
||||
"github.com/xtls/xray-core/features/policy"
|
||||
"github.com/xtls/xray-core/features/stats"
|
||||
"github.com/xtls/xray-core/transport"
|
||||
"github.com/xtls/xray-core/transport/internet"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
)
|
||||
|
||||
// Handler is an outbound connection that silently swallow the entire payload.
|
||||
type Handler struct {
|
||||
conf *DeviceConfig
|
||||
net Tunnel
|
||||
bind *netBindClient
|
||||
policyManager policy.Manager
|
||||
dns dns.Client
|
||||
// cached configuration
|
||||
endpoints []netip.Addr
|
||||
hasIPv4, hasIPv6 bool
|
||||
wgLock sync.Mutex
|
||||
|
||||
streamSettings *internet.MemoryStreamConfig
|
||||
uplinkCounter stats.Counter
|
||||
downlinkCounter stats.Counter
|
||||
|
||||
tun tun.Device
|
||||
tnet *Net
|
||||
dev *device.Device
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// New creates a new wireguard handler.
|
||||
func New(ctx context.Context, conf *DeviceConfig) (*Handler, error) {
|
||||
func NewClient(ctx context.Context, conf *DeviceConfig) (*Handler, error) {
|
||||
v := core.MustFromContext(ctx)
|
||||
p := v.GetFeature(policy.ManagerType()).(policy.Manager)
|
||||
d := v.GetFeature(dns.ClientType()).(dns.Client)
|
||||
|
||||
endpoints, hasIPv4, hasIPv6, err := parseEndpoints(conf)
|
||||
streamSettings := session.StreamSettingsFromContext(ctx).(*internet.MemoryStreamConfig)
|
||||
tag := session.FullHandlerFromContext(ctx).Tag()
|
||||
var uplinkCounter stats.Counter
|
||||
var downlinkCounter stats.Counter
|
||||
if len(tag) > 0 && p.ForSystem().Stats.OutboundUplink {
|
||||
statsManager := v.GetFeature(stats.ManagerType()).(stats.Manager)
|
||||
name := "outbound>>>" + tag + ">>>traffic>>>uplink"
|
||||
c, _ := stats.GetOrRegisterCounter(statsManager, name)
|
||||
if c != nil {
|
||||
uplinkCounter = c
|
||||
}
|
||||
}
|
||||
if len(tag) > 0 && p.ForSystem().Stats.OutboundDownlink {
|
||||
statsManager := v.GetFeature(stats.ManagerType()).(stats.Manager)
|
||||
name := "outbound>>>" + tag + ">>>traffic>>>downlink"
|
||||
c, _ := stats.GetOrRegisterCounter(statsManager, name)
|
||||
if c != nil {
|
||||
downlinkCounter = c
|
||||
}
|
||||
}
|
||||
|
||||
if len(conf.Peers) == 0 {
|
||||
return nil, errors.New("empty peers")
|
||||
}
|
||||
for _, peer := range conf.Peers {
|
||||
if peer.PublicKey == "" {
|
||||
return nil, errors.New("peer without publickey")
|
||||
}
|
||||
if peer.Endpoint == "" {
|
||||
return nil, errors.New("peer without endpoint")
|
||||
}
|
||||
}
|
||||
|
||||
localAddresses := make([]netip.Addr, 0, len(conf.Endpoint))
|
||||
for _, localaddress := range conf.Endpoint {
|
||||
addr, err := netip.ParseAddr(localaddress)
|
||||
if err == nil {
|
||||
localAddresses = append(localAddresses, addr)
|
||||
continue
|
||||
}
|
||||
prefix, err := netip.ParsePrefix(localaddress)
|
||||
if err == nil {
|
||||
localAddresses = append(localAddresses, prefix.Addr())
|
||||
continue
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
kernelTunSupported, err := KernelTunSupported()
|
||||
if err != nil {
|
||||
errors.LogWarningInner(context.Background(), err, "Failed to check kernel TUN support")
|
||||
}
|
||||
var tun tun.Device
|
||||
var tnet *Net
|
||||
if !conf.NoKernelTun && kernelTunSupported {
|
||||
errors.LogWarning(context.Background(), "Using kernel TUN")
|
||||
tun, tnet, err = createKernelTun(localAddresses, []netip.Addr{netip.MustParseAddr("1.1.1.1"), netip.MustParseAddr("1.0.0.1"), netip.MustParseAddr("2606:4700:4700::1111"), netip.MustParseAddr("2606:4700:4700::1001")}, int(conf.Mtu))
|
||||
} else {
|
||||
errors.LogWarning(context.Background(), "Using gVisor TUN")
|
||||
tun, tnet, _, err = CreateNetTUN(localAddresses, []netip.Addr{netip.MustParseAddr("1.1.1.1"), netip.MustParseAddr("1.0.0.1"), netip.MustParseAddr("2606:4700:4700::1111"), netip.MustParseAddr("2606:4700:4700::1001")}, int(conf.Mtu), true)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
d := v.GetFeature(dns.ClientType()).(dns.Client)
|
||||
return &Handler{
|
||||
conf: conf,
|
||||
policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
|
||||
policyManager: p,
|
||||
dns: d,
|
||||
endpoints: endpoints,
|
||||
hasIPv4: hasIPv4,
|
||||
hasIPv6: hasIPv6,
|
||||
|
||||
streamSettings: streamSettings,
|
||||
uplinkCounter: uplinkCounter,
|
||||
downlinkCounter: downlinkCounter,
|
||||
|
||||
tun: tun,
|
||||
tnet: tnet,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (h *Handler) Close() (err error) {
|
||||
go func() {
|
||||
h.wgLock.Lock()
|
||||
defer h.wgLock.Unlock()
|
||||
|
||||
if h.net != nil {
|
||||
_ = h.net.Close()
|
||||
h.net = nil
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Handler) processWireGuard(ctx context.Context, dialer internet.Dialer) (err error) {
|
||||
h.wgLock.Lock()
|
||||
defer h.wgLock.Unlock()
|
||||
|
||||
if h.bind != nil && h.bind.dialer == dialer && h.net != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Record(&log.GeneralMessage{
|
||||
Severity: log.Severity_Info,
|
||||
Content: "switching dialer",
|
||||
})
|
||||
|
||||
if h.net != nil {
|
||||
_ = h.net.Close()
|
||||
h.net = nil
|
||||
}
|
||||
if h.bind != nil {
|
||||
_ = h.bind.Close()
|
||||
h.bind = nil
|
||||
}
|
||||
|
||||
// bind := conn.NewStdNetBind() // TODO: conn.Bind wrapper for dialer
|
||||
h.bind = &netBindClient{
|
||||
netBind: netBind{
|
||||
dns: h.dns,
|
||||
dnsOption: dns.IPOption{
|
||||
IPv4Enable: h.hasIPv4,
|
||||
IPv6Enable: h.hasIPv6,
|
||||
},
|
||||
workers: int(h.conf.NumWorkers),
|
||||
readQueue: make(chan *netReadInfo),
|
||||
},
|
||||
ctx: ctx,
|
||||
dialer: dialer,
|
||||
reserved: h.conf.Reserved,
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
h.bind.Close()
|
||||
h.bind = nil
|
||||
}
|
||||
}()
|
||||
|
||||
h.net, err = h.makeVirtualTun()
|
||||
if err != nil {
|
||||
return errors.New("failed to create virtual tun interface").Base(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Process implements OutboundHandler.Dispatch().
|
||||
// Process implements proxy.Outbound.Process.
|
||||
func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error {
|
||||
outbounds := session.OutboundsFromContext(ctx)
|
||||
ob := outbounds[len(outbounds)-1]
|
||||
@@ -152,40 +139,31 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
|
||||
ob.Name = "wireguard"
|
||||
ob.CanSpliceCopy = 3
|
||||
|
||||
if err := h.processWireGuard(ctx, dialer); err != nil {
|
||||
if h.dev == nil {
|
||||
if err := h.init(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err := h.dev.Up(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Destination of the inner request.
|
||||
destination := ob.Target
|
||||
command := protocol.RequestCommandTCP
|
||||
if destination.Network == net.Network_UDP {
|
||||
command = protocol.RequestCommandUDP
|
||||
var addr netip.Addr
|
||||
if ob.Target.Address.Family().IsDomain() {
|
||||
ip, err := h.resolveRemote(ob.Target.Address.String())
|
||||
if err != nil {
|
||||
return errors.New("failed to resolve domain").Base(err)
|
||||
}
|
||||
addr, _ = netip.AddrFromSlice(ip)
|
||||
} else {
|
||||
addr, _ = netip.AddrFromSlice(ob.Target.Address.IP())
|
||||
}
|
||||
|
||||
// resolve dns
|
||||
addr := destination.Address
|
||||
if addr.Family().IsDomain() {
|
||||
ips, _, err := h.dns.LookupIP(addr.Domain(), dns.IPOption{
|
||||
IPv4Enable: h.hasIPv4 && h.conf.preferIP4(),
|
||||
IPv6Enable: h.hasIPv6 && h.conf.preferIP6(),
|
||||
})
|
||||
{ // Resolve fallback
|
||||
if (len(ips) == 0 || err != nil) && h.conf.hasFallback() {
|
||||
ips, _, err = h.dns.LookupIP(addr.Domain(), dns.IPOption{
|
||||
IPv4Enable: h.hasIPv4 && h.conf.fallbackIP4(),
|
||||
IPv6Enable: h.hasIPv6 && h.conf.fallbackIP6(),
|
||||
})
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return errors.New("failed to lookup DNS").Base(err)
|
||||
} else if len(ips) == 0 {
|
||||
return dns.ErrEmptyResponse
|
||||
}
|
||||
addr = net.IPAddress(ips[dice.Roll(len(ips))])
|
||||
addrPort := netip.AddrPortFrom(addr, ob.Target.Port.Value())
|
||||
if !addrPort.IsValid() {
|
||||
return errors.New("invalid target ", ob.Target)
|
||||
}
|
||||
destination.Address = addr
|
||||
|
||||
var newCtx context.Context
|
||||
var newCancel context.CancelFunc
|
||||
@@ -193,59 +171,64 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
|
||||
newCtx, newCancel = context.WithCancel(context.Background())
|
||||
}
|
||||
|
||||
p := h.policyManager.ForLevel(0)
|
||||
|
||||
sessionPolicy := h.policyManager.ForLevel(0)
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
timer := signal.CancelAfterInactivity(ctx, func() {
|
||||
cancel()
|
||||
if newCancel != nil {
|
||||
newCancel()
|
||||
}
|
||||
}, p.Timeouts.ConnectionIdle)
|
||||
addrPort := netip.AddrPortFrom(toNetIpAddr(addr), destination.Port.Value())
|
||||
}, sessionPolicy.Timeouts.ConnectionIdle)
|
||||
|
||||
var requestFunc func() error
|
||||
var responseFunc func() error
|
||||
if newCtx != nil {
|
||||
ctx = newCtx
|
||||
}
|
||||
|
||||
if command == protocol.RequestCommandTCP {
|
||||
conn, err := h.net.DialContextTCPAddrPort(ctx, addrPort)
|
||||
var reader buf.Reader
|
||||
var writer buf.Writer
|
||||
|
||||
switch ob.Target.Network {
|
||||
case net.Network_TCP:
|
||||
var conn net.Conn
|
||||
var err error
|
||||
if sessionPolicy.Timeouts.Handshake != 0 {
|
||||
timeoutCtx, timeoutCancel := context.WithTimeout(ctx, sessionPolicy.Timeouts.Handshake)
|
||||
conn, err = h.tnet.DialContextTCPAddrPort(timeoutCtx, addrPort)
|
||||
timeoutCancel()
|
||||
} else {
|
||||
conn, err = h.tnet.DialContextTCPAddrPort(ctx, addrPort)
|
||||
}
|
||||
if err != nil {
|
||||
return errors.New("failed to create TCP connection").Base(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
requestFunc = func() error {
|
||||
defer timer.SetTimeout(p.Timeouts.DownlinkOnly)
|
||||
return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer))
|
||||
}
|
||||
responseFunc = func() error {
|
||||
defer timer.SetTimeout(p.Timeouts.UplinkOnly)
|
||||
return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer))
|
||||
}
|
||||
} else if command == protocol.RequestCommandUDP {
|
||||
conn, err := h.net.DialUDPAddrPort(netip.AddrPort{}, addrPort)
|
||||
reader = buf.NewReader(conn)
|
||||
writer = buf.NewWriter(conn)
|
||||
case net.Network_UDP:
|
||||
conn, err := h.tnet.DialUDPAddrPort(netip.AddrPort{}, addrPort)
|
||||
if err != nil {
|
||||
return errors.New("failed to create UDP connection").Base(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
conn = &udpConnClient{
|
||||
Conn: conn,
|
||||
dest: destination,
|
||||
}
|
||||
|
||||
requestFunc = func() error {
|
||||
defer timer.SetTimeout(p.Timeouts.DownlinkOnly)
|
||||
return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer))
|
||||
}
|
||||
responseFunc = func() error {
|
||||
defer timer.SetTimeout(p.Timeouts.UplinkOnly)
|
||||
return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer))
|
||||
c := &udpConnClient{
|
||||
PacketConn: conn.(*internet.PacketConnWrapper).PacketConn,
|
||||
resolveFunc: h.resolveRemote,
|
||||
dest: gonet.UDPAddrFromAddrPort(addrPort),
|
||||
}
|
||||
reader = c
|
||||
writer = c
|
||||
default:
|
||||
panic(ob.Target.Network)
|
||||
}
|
||||
|
||||
if newCtx != nil {
|
||||
ctx = newCtx
|
||||
requestFunc := func() error {
|
||||
defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly)
|
||||
return buf.Copy(link.Reader, writer, buf.UpdateActivity(timer))
|
||||
}
|
||||
|
||||
responseFunc := func() error {
|
||||
defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly)
|
||||
return buf.Copy(reader, link.Writer, buf.UpdateActivity(timer))
|
||||
}
|
||||
|
||||
responseDonePost := task.OnSuccess(responseFunc, task.Close(link.Writer))
|
||||
@@ -258,108 +241,191 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
|
||||
return nil
|
||||
}
|
||||
|
||||
// creates a tun interface on netstack given a configuration
|
||||
func (h *Handler) makeVirtualTun() (Tunnel, error) {
|
||||
t, err := h.conf.createTun()(h.endpoints, int(h.conf.Mtu), nil)
|
||||
func (h *Handler) Close() (err error) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
if h.dev != nil {
|
||||
h.dev.Close()
|
||||
h.dev = nil
|
||||
h.tun = nil
|
||||
} else if h.tun != nil {
|
||||
h.tun.Close()
|
||||
h.tun = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Handler) init(ctx context.Context) error {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
if h.dev != nil {
|
||||
return nil
|
||||
}
|
||||
resolveFunc := h.resolveLocal
|
||||
listenFunc := func() (net.PacketConn, error) {
|
||||
dest, err := net.ParseDestination("udp:" + h.conf.Peers[0].Endpoint)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn, err := internet.DialSystem(ctx, dest, h.streamSettings.SocketSettings)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var pktConn net.PacketConn
|
||||
switch c := conn.(type) {
|
||||
case *internet.PacketConnWrapper:
|
||||
pktConn = c.PacketConn
|
||||
case *cnc.Connection:
|
||||
pktConn = &internet.FakePacketConn{Conn: c}
|
||||
default:
|
||||
panic(reflect.TypeOf(c))
|
||||
}
|
||||
if h.streamSettings.UdpmaskManager != nil {
|
||||
newConn, err := h.streamSettings.UdpmaskManager.WrapPacketConnClient(pktConn)
|
||||
if err != nil {
|
||||
pktConn.Close()
|
||||
return nil, errors.New("mask err").Base(err)
|
||||
}
|
||||
pktConn = newConn
|
||||
}
|
||||
if h.uplinkCounter != nil || h.downlinkCounter != nil {
|
||||
pktConn = &PacketCounterConnection{
|
||||
PacketConn: pktConn,
|
||||
ReadCounter: h.downlinkCounter,
|
||||
WriteCounter: h.uplinkCounter,
|
||||
}
|
||||
}
|
||||
return pktConn, nil
|
||||
}
|
||||
bind := &bind{}
|
||||
logger := &device.Logger{
|
||||
Verbosef: func(format string, args ...any) {
|
||||
log.Record(&log.GeneralMessage{
|
||||
Severity: log.Severity_Debug,
|
||||
Content: fmt.Sprintf(format, args...),
|
||||
})
|
||||
},
|
||||
Errorf: func(format string, args ...any) {
|
||||
log.Record(&log.GeneralMessage{
|
||||
Severity: log.Severity_Error,
|
||||
Content: fmt.Sprintf(format, args...),
|
||||
})
|
||||
},
|
||||
}
|
||||
dev := device.NewDevice(h.tun, bind, logger)
|
||||
bind.resolveFunc = resolveFunc
|
||||
bind.listenFunc = listenFunc
|
||||
bind.downFunc = dev.Down
|
||||
bind.reserved = h.conf.Reserved
|
||||
var cfg strings.Builder
|
||||
cfg.WriteString("private_key=" + h.conf.SecretKey + "\n")
|
||||
for _, peer := range h.conf.Peers {
|
||||
cfg.WriteString("public_key=" + peer.PublicKey + "\n")
|
||||
if peer.PreSharedKey != "" {
|
||||
cfg.WriteString("preshared_key=" + peer.PreSharedKey + "\n")
|
||||
}
|
||||
cfg.WriteString("endpoint=" + peer.Endpoint + "\n")
|
||||
for _, ip := range peer.AllowedIps {
|
||||
cfg.WriteString("allowed_ip=" + ip + "\n")
|
||||
}
|
||||
if peer.KeepAlive != "" {
|
||||
cfg.WriteString("persistent_keepalive_interval=" + peer.KeepAlive + "\n")
|
||||
}
|
||||
}
|
||||
err := dev.IpcSet(cfg.String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = dev.Up()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
h.dev = dev
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Handler) resolveLocal(host string) (net.IP, error) {
|
||||
return resolveDomain(host, h.conf.DomainStrategy, func(host string) ([]net.IP, error) {
|
||||
ips, _, err := h.dns.LookupIP(host, dns.IPOption{IPv4Enable: true, IPv6Enable: true})
|
||||
return ips, err
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) resolveRemote(host string) (net.IP, error) {
|
||||
return resolveDomain(host, h.conf.DomainStrategy, func(host string) ([]net.IP, error) {
|
||||
addrs, err := h.tnet.LookupHost(host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ips := make([]net.IP, 0, len(addrs))
|
||||
for _, addr := range addrs {
|
||||
ips = append(ips, net.ParseIP(addr))
|
||||
}
|
||||
return ips, nil
|
||||
})
|
||||
}
|
||||
|
||||
func resolveDomain(host string, strategy DeviceConfig_DomainStrategy, lookupIP func(host string) ([]net.IP, error)) (net.IP, error) {
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
return ip, nil
|
||||
}
|
||||
ips, err := lookupIP(host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
h.bind.dnsOption.IPv4Enable = h.hasIPv4
|
||||
h.bind.dnsOption.IPv6Enable = h.hasIPv6
|
||||
|
||||
if err = t.BuildDevice(h.createIPCRequest(), h.bind); err != nil {
|
||||
_ = t.Close()
|
||||
return nil, err
|
||||
if len(ips) == 0 {
|
||||
return nil, dns.ErrEmptyResponse
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// serialize the config into an IPC request
|
||||
func (h *Handler) createIPCRequest() string {
|
||||
var request strings.Builder
|
||||
|
||||
request.WriteString(fmt.Sprintf("private_key=%s\n", h.conf.SecretKey))
|
||||
|
||||
if !h.conf.IsClient {
|
||||
// placeholder, we'll handle actual port listening on Xray
|
||||
request.WriteString("listen_port=1337\n")
|
||||
}
|
||||
|
||||
for _, peer := range h.conf.Peers {
|
||||
if peer.PublicKey != "" {
|
||||
request.WriteString(fmt.Sprintf("public_key=%s\n", peer.PublicKey))
|
||||
}
|
||||
|
||||
if peer.PreSharedKey != "" {
|
||||
request.WriteString(fmt.Sprintf("preshared_key=%s\n", peer.PreSharedKey))
|
||||
}
|
||||
|
||||
address, port, err := net.SplitHostPort(peer.Endpoint)
|
||||
if err != nil {
|
||||
errors.LogError(h.bind.ctx, "failed to split endpoint ", peer.Endpoint, " into address and port")
|
||||
}
|
||||
addr := net.ParseAddress(address)
|
||||
if addr.Family().IsDomain() {
|
||||
dialerIp := h.bind.dialer.DestIpAddress()
|
||||
if dialerIp != nil {
|
||||
addr = net.ParseAddress(dialerIp.String())
|
||||
errors.LogInfo(h.bind.ctx, "createIPCRequest use dialer dest ip: ", addr)
|
||||
} else {
|
||||
ips, _, err := h.dns.LookupIP(addr.Domain(), dns.IPOption{
|
||||
IPv4Enable: h.conf.preferIP4(),
|
||||
IPv6Enable: h.conf.preferIP6(),
|
||||
})
|
||||
{ // Resolve fallback
|
||||
if (len(ips) == 0 || err != nil) && h.conf.hasFallback() {
|
||||
ips, _, err = h.dns.LookupIP(addr.Domain(), dns.IPOption{
|
||||
IPv4Enable: h.conf.fallbackIP4(),
|
||||
IPv6Enable: h.conf.fallbackIP6(),
|
||||
})
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
errors.LogInfoInner(h.bind.ctx, err, "createIPCRequest failed to lookup DNS")
|
||||
} else if len(ips) == 0 {
|
||||
errors.LogInfo(h.bind.ctx, "createIPCRequest empty lookup DNS")
|
||||
} else {
|
||||
addr = net.IPAddress(ips[dice.Roll(len(ips))])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if peer.Endpoint != "" {
|
||||
request.WriteString(fmt.Sprintf("endpoint=%s:%s\n", addr, port))
|
||||
}
|
||||
|
||||
for _, ip := range peer.AllowedIps {
|
||||
request.WriteString(fmt.Sprintf("allowed_ip=%s\n", ip))
|
||||
}
|
||||
|
||||
if peer.KeepAlive != 0 {
|
||||
request.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", peer.KeepAlive))
|
||||
var got4, got6 []net.IP
|
||||
for _, ip := range ips {
|
||||
if ip.To4() != nil {
|
||||
got4 = append(got4, ip)
|
||||
} else {
|
||||
got6 = append(got6, ip)
|
||||
}
|
||||
}
|
||||
|
||||
return request.String()[:request.Len()]
|
||||
var got []net.IP
|
||||
switch strategy {
|
||||
case DeviceConfig_FORCE_IP:
|
||||
got = ips
|
||||
return ips[dice.Roll(len(ips))], nil
|
||||
case DeviceConfig_FORCE_IP4:
|
||||
got = got4
|
||||
case DeviceConfig_FORCE_IP6:
|
||||
got = got6
|
||||
case DeviceConfig_FORCE_IP46:
|
||||
got = got4
|
||||
if len(got) == 0 {
|
||||
got = got6
|
||||
}
|
||||
case DeviceConfig_FORCE_IP64:
|
||||
got = got6
|
||||
if len(got) == 0 {
|
||||
got = got4
|
||||
}
|
||||
default:
|
||||
panic(strategy)
|
||||
}
|
||||
if len(got) == 0 {
|
||||
return nil, dns.ErrEmptyResponse
|
||||
}
|
||||
return got[dice.Roll(len(got))], nil
|
||||
}
|
||||
|
||||
type udpConnClient struct {
|
||||
net.Conn
|
||||
dest net.Destination
|
||||
net.PacketConn
|
||||
resolveFunc func(host string) (net.IP, error)
|
||||
dest *net.UDPAddr
|
||||
}
|
||||
|
||||
func (c *udpConnClient) ReadMultiBuffer() (buf.MultiBuffer, error) {
|
||||
b := buf.New()
|
||||
b.Resize(0, buf.Size)
|
||||
n, addr, err := c.Conn.(net.PacketConn).ReadFrom(b.Bytes())
|
||||
n, addr, err := c.PacketConn.ReadFrom(b.Bytes())
|
||||
if err != nil {
|
||||
b.Release()
|
||||
return nil, err
|
||||
}
|
||||
if addr == nil { // should never hit
|
||||
addr = c.dest.RawNetAddr()
|
||||
}
|
||||
b.Resize(0, int32(n))
|
||||
|
||||
b.UDP = &net.Destination{
|
||||
@@ -375,9 +441,22 @@ func (c *udpConnClient) WriteMultiBuffer(mb buf.MultiBuffer) error {
|
||||
for i, b := range mb {
|
||||
dst := c.dest
|
||||
if b.UDP != nil {
|
||||
dst = *b.UDP
|
||||
if b.UDP.Address.Family().IsDomain() {
|
||||
ip, err := c.resolveFunc(b.UDP.Address.String())
|
||||
if err != nil {
|
||||
errors.LogErrorInner(context.Background(), err, "drop packet to ", b.UDP, " with size ", len(b.Bytes()))
|
||||
b.Release()
|
||||
continue
|
||||
}
|
||||
dst = &net.UDPAddr{
|
||||
IP: ip,
|
||||
Port: int(b.UDP.Port),
|
||||
}
|
||||
} else {
|
||||
dst = b.UDP.RawNetAddr().(*net.UDPAddr)
|
||||
}
|
||||
}
|
||||
_, err := c.Conn.(net.PacketConn).WriteTo(b.Bytes(), dst.RawNetAddr())
|
||||
_, err := c.PacketConn.WriteTo(b.Bytes(), dst)
|
||||
if err != nil {
|
||||
buf.ReleaseMulti(mb[i:])
|
||||
return err
|
||||
@@ -386,3 +465,25 @@ func (c *udpConnClient) WriteMultiBuffer(mb buf.MultiBuffer) error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type PacketCounterConnection struct {
|
||||
net.PacketConn
|
||||
ReadCounter stats.Counter
|
||||
WriteCounter stats.Counter
|
||||
}
|
||||
|
||||
func (c *PacketCounterConnection) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
n, addr, err = c.PacketConn.ReadFrom(p)
|
||||
if err == nil && c.ReadCounter != nil {
|
||||
c.ReadCounter.Add(int64(n))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *PacketCounterConnection) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||
n, err = c.PacketConn.WriteTo(p, addr)
|
||||
if err == nil && c.WriteCounter != nil {
|
||||
c.WriteCounter.Add(int64(n))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user