From 36303694d1edfeb1b83bf6e0f79caebf4651c063 Mon Sep 17 00:00:00 2001 From: LjhAUMEM Date: Fri, 29 May 2026 16:29:53 +0800 Subject: [PATCH] Finalmask: Add Realm (UDP hole punching in Hysteria v2.9.1) (#6137) https://github.com/XTLS/Xray-core/pull/5657#issuecomment-4446406536 https://github.com/XTLS/Xray-core/pull/6137#issuecomment-4469822775 Example: https://github.com/XTLS/Xray-core/pull/6137#issue-4454013510 --- .github/workflows/release.yml | 3 + README.md | 7 + go.mod | 5 + go.sum | 10 + infra/conf/transport_internet.go | 88 ++++ transport/internet/finalmask/realm/client.go | 171 ++++++++ transport/internet/finalmask/realm/config.go | 27 ++ .../internet/finalmask/realm/config.pb.go | 181 ++++++++ .../internet/finalmask/realm/config.proto | 19 + transport/internet/finalmask/realm/http.go | 307 ++++++++++++++ transport/internet/finalmask/realm/punch.go | 145 +++++++ transport/internet/finalmask/realm/server.go | 401 ++++++++++++++++++ transport/internet/finalmask/realm/stun.go | 224 ++++++++++ .../congestion/bbr/bandwidth_sampler.go | 6 +- .../hysteria/congestion/bbr/bbr_sender.go | 3 +- 15 files changed, 1594 insertions(+), 3 deletions(-) create mode 100644 transport/internet/finalmask/realm/client.go create mode 100644 transport/internet/finalmask/realm/config.go create mode 100644 transport/internet/finalmask/realm/config.pb.go create mode 100644 transport/internet/finalmask/realm/config.proto create mode 100644 transport/internet/finalmask/realm/http.go create mode 100644 transport/internet/finalmask/realm/punch.go create mode 100644 transport/internet/finalmask/realm/server.go create mode 100644 transport/internet/finalmask/realm/stun.go diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 589eddb1..b1b46875 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -208,6 +208,9 @@ jobs: go build -o build_assets/xray.exe -trimpath -buildvcs=false -gcflags="all=-l=4" -ldflags="-X github.com/xtls/xray-core/core.build=${COMMID} -s -w -buildid=" -v ./main # The line below is for without running conhost.exe version. Commented for not being used. Provided for reference. # go build -o build_assets/wxray.exe -trimpath -buildvcs=false -gcflags="all=-l=4" -ldflags="-H windowsgui -X github.com/xtls/xray-core/core.build=${COMMID} -s -w -buildid=" -v ./main + elif [[ ${GOOS} == 'android' ]]; then + echo 'Building Xray for Android...' + go build -o build_assets/xray -trimpath -buildvcs=false -gcflags="all=-l=4" -ldflags="-X github.com/xtls/xray-core/core.build=${COMMID} -s -w -buildid= -checklinkname=0" -v ./main else echo 'Building Xray...' if [[ ${GOARCH} == 'mips' || ${GOARCH} == 'mipsle' ]]; then diff --git a/README.md b/README.md index cad3ce83..deeb2904 100644 --- a/README.md +++ b/README.md @@ -207,6 +207,13 @@ Make sure that you are using the same Go version, and remember to set the git co CGO_ENABLED=0 go build -o xray -trimpath -buildvcs=false -gcflags="all=-l=4" -ldflags="-X github.com/xtls/xray-core/core.build=REPLACE -s -w -buildid=" -v ./main ``` +For Android: + +```bash +GOOS=android GOARCH=arm64 CGO_ENABLED=1 CC=/path/to/aarch64-linux-android24-clang go build -o xray -trimpath -buildvcs=false -gcflags="all=-l=4" -ldflags="-X github.com/xtls/xray-core/core.build=REPLACE -s -w -buildid= -checklinkname=0" -v ./main +GOOS=android GOARCH=amd64 CGO_ENABLED=1 CC=/path/to/x86_64-linux-android24-clang go build -o xray -trimpath -buildvcs=false -gcflags="all=-l=4" -ldflags="-X github.com/xtls/xray-core/core.build=REPLACE -s -w -buildid= -checklinkname=0" -v ./main +``` + If you are compiling a 32-bit MIPS/MIPSLE target, use this command instead: ```bash diff --git a/go.mod b/go.mod index e4c4b8e2..c5b21f68 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/klauspost/cpuid/v2 v2.3.0 github.com/miekg/dns v1.1.72 github.com/pelletier/go-toml v1.9.5 + github.com/pion/stun/v3 v3.1.2 github.com/pires/go-proxyproto v0.12.0 github.com/refraction-networking/utls v1.8.3-0.20260301010127-aa6edf4b11af github.com/robfig/cron/v3 v3.0.1 @@ -43,9 +44,13 @@ require ( github.com/juju/ratelimit v1.0.2 // indirect github.com/klauspost/compress v1.17.4 // indirect github.com/kr/text v0.2.0 // indirect + github.com/pion/dtls/v3 v3.1.2 // indirect + github.com/pion/logging v0.2.4 // indirect + github.com/pion/transport/v4 v4.0.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/quic-go/qpack v0.6.0 // indirect github.com/vishvananda/netns v0.0.5 // indirect + github.com/wlynxg/anet v0.0.5 // indirect golang.org/x/mod v0.35.0 // indirect golang.org/x/text v0.37.0 // indirect golang.org/x/time v0.12.0 // indirect diff --git a/go.sum b/go.sum index 1e0d65bf..317452b5 100644 --- a/go.sum +++ b/go.sum @@ -45,6 +45,14 @@ github.com/pelletier/go-toml v1.9.5 h1:4yBQzkHv+7BHq2PQUZF3Mx0IYxG7LsP222s7Agd3v github.com/pelletier/go-toml v1.9.5/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c= github.com/phayes/freeport v0.0.0-20180830031419-95f893ade6f2 h1:JhzVVoYvbOACxoUmOs6V/G4D5nPVUW73rKvXxP4XUJc= github.com/phayes/freeport v0.0.0-20180830031419-95f893ade6f2/go.mod h1:iIss55rKnNBTvrwdmkUpLnDpZoAHvWaiq5+iMmen4AE= +github.com/pion/dtls/v3 v3.1.2 h1:gqEdOUXLtCGW+afsBLO0LtDD8GnuBBjEy6HRtyofZTc= +github.com/pion/dtls/v3 v3.1.2/go.mod h1:Hw/igcX4pdY69z1Hgv5x7wJFrUkdgHwAn/Q/uo7YHRo= +github.com/pion/logging v0.2.4 h1:tTew+7cmQ+Mc1pTBLKH2puKsOvhm32dROumOZ655zB8= +github.com/pion/logging v0.2.4/go.mod h1:DffhXTKYdNZU+KtJ5pyQDjvOAh/GsNSyv1lbkFbe3so= +github.com/pion/stun/v3 v3.1.2 h1:86IhD8wFn6IDW4b1/0QzoQS+f5PeA8OHHRn8UZW5ErY= +github.com/pion/stun/v3 v3.1.2/go.mod h1:H7gDic7nNwlUL05pbs6T1dtaBehh/KjupxfWw3ZI7cA= +github.com/pion/transport/v4 v4.0.1 h1:sdROELU6BZ63Ab7FrOLn13M6YdJLY20wldXW2Cu2k8o= +github.com/pion/transport/v4 v4.0.1/go.mod h1:nEuEA4AD5lPdcIegQDpVLgNoDGreqM/YqmEx3ovP4jM= github.com/pires/go-proxyproto v0.12.0 h1:TTCxD66dU898tahivkqc3hoceZp7P44FnorWyo9d5vM= github.com/pires/go-proxyproto v0.12.0/go.mod h1:qUvfqUMEoX7T8g0q7TQLDnhMjdTrxnG0hvpMn+7ePNI= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -67,6 +75,8 @@ github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4= github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY= github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= +github.com/wlynxg/anet v0.0.5 h1:J3VJGi1gvo0JwZ/P1/Yc/8p63SoW98B5dHkYDmpgvvU= +github.com/wlynxg/anet v0.0.5/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA= github.com/xtls/reality v0.0.0-20260322125925-9234c772ba8f h1:iy2JRioxmUpoJ3SzbFPyTxHZMbR/rSHP7dOOgYaq1O8= github.com/xtls/reality v0.0.0-20260322125925-9234c772ba8f/go.mod h1:DsJblcWDGt76+FVqBVwbwRhxyyNJsGV48gJLch0OOWI= github.com/yuin/goldmark v1.4.1/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= diff --git a/infra/conf/transport_internet.go b/infra/conf/transport_internet.go index 0fd1c767..02f94de0 100644 --- a/infra/conf/transport_internet.go +++ b/infra/conf/transport_internet.go @@ -32,6 +32,7 @@ import ( "github.com/xtls/xray-core/transport/internet/finalmask/mkcp/aes128gcm" "github.com/xtls/xray-core/transport/internet/finalmask/mkcp/original" "github.com/xtls/xray-core/transport/internet/finalmask/noise" + "github.com/xtls/xray-core/transport/internet/finalmask/realm" "github.com/xtls/xray-core/transport/internet/finalmask/salamander" finalsudoku "github.com/xtls/xray-core/transport/internet/finalmask/sudoku" "github.com/xtls/xray-core/transport/internet/finalmask/xdns" @@ -1255,6 +1256,7 @@ var ( "sudoku": func() interface{} { return new(Sudoku) }, "xdns": func() interface{} { return new(Xdns) }, "xicmp": func() interface{} { return new(Xicmp) }, + "realm": func() interface{} { return new(Realm) }, }, "type", "settings") ) @@ -1908,6 +1910,92 @@ func (c *Xicmp) Build() (proto.Message, error) { return config, nil } +type Realm struct { + Url string `json:"url"` + StunServers []string `json:"stunServers"` + TlsConfig *TLSConfig `json:"tlsConfig"` +} + +func (c *Realm) Build() (proto.Message, error) { + var scheme, host, port, token, id string + var stunServers []string + var tlsConfig *tls.Config + + u, err := url.Parse(c.Url) + if err != nil { + return nil, err + } + + switch u.Scheme { + case "realm": + scheme = "https" + case "realm+http": + scheme = "http" + default: + return nil, errors.New("invalid scheme", u.Scheme) + } + + host = u.Hostname() + if host == "" { + return nil, errors.New("invalid host", host) + } + + port = u.Port() + if port == "" { + port = "443" + if scheme == "http" { + port = "80" + } + } + + token, err = url.PathUnescape(u.User.String()) + if err != nil { + return nil, err + } + if token == "" { + return nil, errors.New("invalid token", token) + } + + id, err = url.PathUnescape(strings.TrimPrefix(u.EscapedPath(), "/")) + if err != nil { + return nil, err + } + if id == "" { + return nil, errors.New("invalid id", id) + } + + if len(c.StunServers) == 0 { + return nil, errors.New("empty stunServers") + } + + for _, s := range c.StunServers { + _, _, err = net.SplitHostPort(s) + if err != nil { + return nil, err + } + } + + stunServers = c.StunServers + + if c.TlsConfig != nil { + tc, err := c.TlsConfig.Build() + if err != nil { + return nil, err + } + tlsConfig = tc.(*tls.Config) + } + + return &realm.Config{ + Scheme: scheme, + Host: host, + Port: port, + Token: token, + ID: id, + StunServers: stunServers, + TlsConfig: tlsConfig, + }, nil +} + type Mask struct { Type string `json:"type"` Settings *json.RawMessage `json:"settings"` diff --git a/transport/internet/finalmask/realm/client.go b/transport/internet/finalmask/realm/client.go new file mode 100644 index 00000000..ab7ba6bb --- /dev/null +++ b/transport/internet/finalmask/realm/client.go @@ -0,0 +1,171 @@ +package realm + +import ( + "context" + goerrors "errors" + "net" + "net/netip" + "slices" + "strings" + "time" + + "github.com/pion/stun/v3" + "github.com/xtls/xray-core/common" + "github.com/xtls/xray-core/common/errors" +) + +type realmConnClient struct { + net.PacketConn + peer *net.UDPAddr + + realmClient *Client + realmID string + stunServers []string + stunTimeout time.Duration + punchTimeout time.Duration + punchInterval time.Duration +} + +func NewConnClient(config *Config, raw net.PacketConn) (net.PacketConn, error) { + conn := &realmConnClient{ + PacketConn: raw, + + realmClient: NewClient(config.Scheme, config.Host, config.Port, config.Token, config.TlsConfig), + realmID: config.ID, + stunServers: config.StunServers, + stunTimeout: defaultSTUNTimeout, + punchTimeout: defaultPunchTimeout, + punchInterval: defaultPunchInterval, + } + + return conn.getpeer() +} + +func (c *realmConnClient) getpeer() (net.PacketConn, error) { + start := time.Now() + servers := resolveSTUNServers(c.PacketConn.LocalAddr().(*net.UDPAddr).IP, c.stunServers) + errors.LogDebug(context.Background(), "[realm] update stun servers ", servers, " with ", time.Since(start)) + if len(servers) == 0 { + return nil, errors.New("empty locals") + } + + start = time.Now() + locals := c.discover(servers) + errors.LogDebug(context.Background(), "[realm] update stun locals ", locals, " with ", time.Since(start)) + if len(locals) == 0 { + return nil, errors.New("empty locals") + } + + meta := common.Must2(NewPunchMetadata()) + + start = time.Now() + resp, err := c.realmClient.Connect(context.Background(), c.realmID, ConnectRequest{ + Addresses: addrPortStrings(locals), + PunchMetadata: meta, + }) + if err != nil { + return nil, err + } + errors.LogDebug(context.Background(), "[realm] ", c.realmID, " ", meta.Nonce, " connect ", resp.Addresses, " with ", time.Since(start)) + + peers, _ := parseAddrPorts(resp.Addresses) + errors.LogDebug(context.Background(), "[realm] update peers ", peers) + filteredPeers, seen := candidatePunchAddrs(locals, peers) + errors.LogDebug(context.Background(), "[realm] filtered peers ", filteredPeers) + expandedPeers := expandSymmetricNATCandidates(filteredPeers, seen) + errors.LogDebug(context.Background(), "[realm] expanded peers ", expandedPeers) + + if len(expandedPeers) == 0 { + return nil, errors.New("empty peers") + } + + start = time.Now() + peer, err := c.punch(meta, peers) + if err != nil { + return nil, errors.New("punch fail").Base(err) + } + errors.LogDebug(context.Background(), "[realm] punch peer ", peer, " with ", time.Since(start)) + + c.peer = peer + return c, nil +} + +func (c *realmConnClient) discover(servers []*net.UDPAddr) []netip.AddrPort { + var transactionIDs = make(map[[stun.TransactionIDSize]byte]struct{}, len(servers)) + for _, server := range servers { + msg := common.Must2(stun.Build(stun.TransactionID, stun.BindingRequest)) + transactionIDs[msg.TransactionID] = struct{}{} + _, _ = c.PacketConn.WriteTo(msg.Raw, server) + } + + var buf = make([]byte, 1500) + var results = make([]netip.AddrPort, 0, len(servers)) + c.PacketConn.SetReadDeadline(time.Now().Add(defaultSTUNTimeout)) + for len(transactionIDs) > 0 { + n, _, err := c.PacketConn.ReadFrom(buf) + if err != nil { + break + } + msg, addrPort, err := parseSTUNBindingResponse(buf[:n]) + if err != nil { + continue + } + if _, ok := transactionIDs[msg.TransactionID]; ok { + delete(transactionIDs, msg.TransactionID) + results = append(results, addrPort) + } + } + c.PacketConn.SetReadDeadline(time.Time{}) + slices.SortFunc(results, func(a, b netip.AddrPort) int { + return strings.Compare(a.String(), b.String()) + }) + + return results +} + +func (c *realmConnClient) punch(meta PunchMetadata, peers []netip.AddrPort) (*net.UDPAddr, error) { + defer c.PacketConn.SetReadDeadline(time.Time{}) + nextSend := time.Now() + deadline := nextSend.Add(c.punchTimeout) + buf := make([]byte, punchMaxWireLen) + for { + now := time.Now() + if now.After(deadline) { + return nil, errors.New("timeout") + } + if now.After(nextSend) { + for _, peer := range peers { + packet := common.Must2(EncodePunchPacket(PunchPacketHello, meta)) + _, _ = c.PacketConn.WriteTo(packet, net.UDPAddrFromAddrPort(peer)) + } + nextSend = now.Add(c.punchInterval) + } + + if nextSend.After(deadline) { + c.PacketConn.SetReadDeadline(deadline) + } else { + c.PacketConn.SetReadDeadline(nextSend) + } + n, addr, err := c.PacketConn.ReadFrom(buf) + if err != nil { + var netErr net.Error + if goerrors.As(err, &netErr) && netErr.Timeout() { + continue + } + return nil, err + } + packet, err := DecodePunchPacket(buf[:n], meta) + if err != nil { + continue + } + if packet.Type == PunchPacketHello { + packet := common.Must2(EncodePunchPacket(PunchPacketAck, meta)) + _, _ = c.PacketConn.WriteTo(packet, addr) + } + return addr.(*net.UDPAddr), nil + } +} + +func (c *realmConnClient) WriteTo(p []byte, addr net.Addr) (n int, err error) { + return c.PacketConn.WriteTo(p, c.peer) +} diff --git a/transport/internet/finalmask/realm/config.go b/transport/internet/finalmask/realm/config.go new file mode 100644 index 00000000..2604fa8a --- /dev/null +++ b/transport/internet/finalmask/realm/config.go @@ -0,0 +1,27 @@ +package realm + +import ( + "net" + + "github.com/xtls/xray-core/common/errors" + "github.com/xtls/xray-core/transport/internet" + "github.com/xtls/xray-core/transport/internet/hysteria/udphop" +) + +func (c *Config) UDP() {} + +func (c *Config) WrapPacketConnClient(raw net.PacketConn, level int, levelCount int) (net.PacketConn, error) { + _, ok1 := raw.(*internet.FakePacketConn) + _, ok2 := raw.(*udphop.UdpHopPacketConn) + if level != 0 || ok1 || ok2 { + return nil, errors.New("realm requires being at the outermost level") + } + return NewConnClient(c, raw) +} + +func (c *Config) WrapPacketConnServer(raw net.PacketConn, level int, levelCount int) (net.PacketConn, error) { + if level != 0 { + return nil, errors.New("realm requires being at the outermost level") + } + return NewConnServer(c, raw) +} diff --git a/transport/internet/finalmask/realm/config.pb.go b/transport/internet/finalmask/realm/config.pb.go new file mode 100644 index 00000000..ee273e73 --- /dev/null +++ b/transport/internet/finalmask/realm/config.pb.go @@ -0,0 +1,181 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.11 +// protoc v6.33.5 +// source: transport/internet/finalmask/realm/config.proto + +package realm + +import ( + tls "github.com/xtls/xray-core/transport/internet/tls" + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type Config struct { + state protoimpl.MessageState `protogen:"open.v1"` + Scheme string `protobuf:"bytes,1,opt,name=scheme,proto3" json:"scheme,omitempty"` + Host string `protobuf:"bytes,2,opt,name=host,proto3" json:"host,omitempty"` + Port string `protobuf:"bytes,3,opt,name=port,proto3" json:"port,omitempty"` + Token string `protobuf:"bytes,4,opt,name=token,proto3" json:"token,omitempty"` + ID string `protobuf:"bytes,5,opt,name=ID,proto3" json:"ID,omitempty"` + StunServers []string `protobuf:"bytes,6,rep,name=stun_servers,json=stunServers,proto3" json:"stun_servers,omitempty"` + TlsConfig *tls.Config `protobuf:"bytes,7,opt,name=tls_config,json=tlsConfig,proto3" json:"tls_config,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Config) Reset() { + *x = Config{} + mi := &file_transport_internet_finalmask_realm_config_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Config) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Config) ProtoMessage() {} + +func (x *Config) ProtoReflect() protoreflect.Message { + mi := &file_transport_internet_finalmask_realm_config_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Config.ProtoReflect.Descriptor instead. +func (*Config) Descriptor() ([]byte, []int) { + return file_transport_internet_finalmask_realm_config_proto_rawDescGZIP(), []int{0} +} + +func (x *Config) GetScheme() string { + if x != nil { + return x.Scheme + } + return "" +} + +func (x *Config) GetHost() string { + if x != nil { + return x.Host + } + return "" +} + +func (x *Config) GetPort() string { + if x != nil { + return x.Port + } + return "" +} + +func (x *Config) GetToken() string { + if x != nil { + return x.Token + } + return "" +} + +func (x *Config) GetID() string { + if x != nil { + return x.ID + } + return "" +} + +func (x *Config) GetStunServers() []string { + if x != nil { + return x.StunServers + } + return nil +} + +func (x *Config) GetTlsConfig() *tls.Config { + if x != nil { + return x.TlsConfig + } + return nil +} + +var File_transport_internet_finalmask_realm_config_proto protoreflect.FileDescriptor + +const file_transport_internet_finalmask_realm_config_proto_rawDesc = "" + + "\n" + + "/transport/internet/finalmask/realm/config.proto\x12'xray.transport.internet.finalmask.realm\x1a#transport/internet/tls/config.proto\"\xd5\x01\n" + + "\x06Config\x12\x16\n" + + "\x06scheme\x18\x01 \x01(\tR\x06scheme\x12\x12\n" + + "\x04host\x18\x02 \x01(\tR\x04host\x12\x12\n" + + "\x04port\x18\x03 \x01(\tR\x04port\x12\x14\n" + + "\x05token\x18\x04 \x01(\tR\x05token\x12\x0e\n" + + "\x02ID\x18\x05 \x01(\tR\x02ID\x12!\n" + + "\fstun_servers\x18\x06 \x03(\tR\vstunServers\x12B\n" + + "\n" + + "tls_config\x18\a \x01(\v2#.xray.transport.internet.tls.ConfigR\ttlsConfigB\x97\x01\n" + + "+com.xray.transport.internet.finalmask.realmP\x01Z xray.transport.internet.tls.Config + 1, // [1:1] is the sub-list for method output_type + 1, // [1:1] is the sub-list for method input_type + 1, // [1:1] is the sub-list for extension type_name + 1, // [1:1] is the sub-list for extension extendee + 0, // [0:1] is the sub-list for field type_name +} + +func init() { file_transport_internet_finalmask_realm_config_proto_init() } +func file_transport_internet_finalmask_realm_config_proto_init() { + if File_transport_internet_finalmask_realm_config_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_transport_internet_finalmask_realm_config_proto_rawDesc), len(file_transport_internet_finalmask_realm_config_proto_rawDesc)), + NumEnums: 0, + NumMessages: 1, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_transport_internet_finalmask_realm_config_proto_goTypes, + DependencyIndexes: file_transport_internet_finalmask_realm_config_proto_depIdxs, + MessageInfos: file_transport_internet_finalmask_realm_config_proto_msgTypes, + }.Build() + File_transport_internet_finalmask_realm_config_proto = out.File + file_transport_internet_finalmask_realm_config_proto_goTypes = nil + file_transport_internet_finalmask_realm_config_proto_depIdxs = nil +} diff --git a/transport/internet/finalmask/realm/config.proto b/transport/internet/finalmask/realm/config.proto new file mode 100644 index 00000000..62dce86e --- /dev/null +++ b/transport/internet/finalmask/realm/config.proto @@ -0,0 +1,19 @@ +syntax = "proto3"; + +package xray.transport.internet.finalmask.realm; +option csharp_namespace = "Xray.Transport.Internet.Finalmask.Realm"; +option go_package = "github.com/xtls/xray-core/transport/internet/finalmask/realm"; +option java_package = "com.xray.transport.internet.finalmask.realm"; +option java_multiple_files = true; + +import "transport/internet/tls/config.proto"; + +message Config { + string scheme = 1; + string host = 2; + string port = 3; + string token = 4; + string ID = 5; + repeated string stun_servers = 6; + xray.transport.internet.tls.Config tls_config = 7; +} \ No newline at end of file diff --git a/transport/internet/finalmask/realm/http.go b/transport/internet/finalmask/realm/http.go new file mode 100644 index 00000000..2f58ee89 --- /dev/null +++ b/transport/internet/finalmask/realm/http.go @@ -0,0 +1,307 @@ +package realm + +import ( + "bufio" + "bytes" + "context" + "crypto/rand" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "net/url" + "strings" + + "github.com/xtls/xray-core/transport/internet/tls" +) + +const maxErrorBodySize = 64 * 1024 + +const ( + PunchNonceSize = 16 + PunchObfsKeySize = 32 +) + +type Client struct { + scheme string + hostport string + token string + httpClient *http.Client +} + +type RegisterResponse struct { + SessionID string `json:"session_id"` + TTL int `json:"ttl"` +} + +type HeartbeatResponse struct { + TTL int `json:"ttl"` +} + +type HeartbeatRequest struct { + Addresses []string `json:"addresses,omitempty"` +} + +type PunchMetadata struct { + Nonce string `json:"nonce"` + Obfs string `json:"obfs"` +} + +type ConnectRequest struct { + Addresses []string `json:"addresses"` + PunchMetadata +} + +type ConnectResponse struct { + Addresses []string `json:"addresses"` + PunchMetadata +} + +type PunchEvent struct { + Addresses []string `json:"addresses"` + PunchMetadata +} + +type ErrorResponse struct { + Error string `json:"error"` + Message string `json:"message"` +} + +type StatusError struct { + StatusCode int + Response ErrorResponse +} + +func (e *StatusError) Error() string { + if e.Response.Error != "" || e.Response.Message != "" { + return fmt.Sprintf("realm server returned %d: %s: %s", e.StatusCode, e.Response.Error, e.Response.Message) + } + return fmt.Sprintf("realm server returned %d", e.StatusCode) +} + +func NewClient(scheme, host, port, token string, tlsConfig *tls.Config) *Client { + client := http.DefaultClient + if tlsConfig != nil { + tr := http.DefaultTransport.(*http.Transport).Clone() + tr.TLSClientConfig = tlsConfig.GetTLSConfig() + client = &http.Client{Transport: tr} + } + return &Client{ + scheme: scheme, + hostport: net.JoinHostPort(host, port), + token: token, + httpClient: client, + } +} + +func NewPunchMetadata() (PunchMetadata, error) { + nonce, err := randHex(PunchNonceSize) + if err != nil { + return PunchMetadata{}, err + } + obfs, err := randHex(PunchObfsKeySize) + if err != nil { + return PunchMetadata{}, err + } + return PunchMetadata{ + Nonce: nonce, + Obfs: obfs, + }, nil +} + +func (c *Client) Register(ctx context.Context, realmID string, addresses []string) (*RegisterResponse, error) { + var resp RegisterResponse + if err := c.doJSON(ctx, http.MethodPost, realmID, "", c.token, addressRequest{Addresses: addresses}, http.StatusOK, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +func (c *Client) Deregister(ctx context.Context, realmID, sessionID string) error { + return c.doJSON(ctx, http.MethodDelete, realmID, "", sessionID, nil, http.StatusNoContent, nil) +} + +func (c *Client) Heartbeat(ctx context.Context, realmID, sessionID string, req HeartbeatRequest) (*HeartbeatResponse, error) { + var resp HeartbeatResponse + if err := c.doJSON(ctx, http.MethodPost, realmID, "heartbeat", sessionID, req, http.StatusOK, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +func (c *Client) Connect(ctx context.Context, realmID string, req ConnectRequest) (*ConnectResponse, error) { + var resp ConnectResponse + if err := c.doJSON(ctx, http.MethodPost, realmID, "connect", c.token, req, http.StatusOK, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +type ConnectResponseRequest struct { + Addresses []string `json:"addresses"` +} + +func (c *Client) ConnectResponse(ctx context.Context, realmID, sessionID, nonce string, addresses []string) error { + subPath := "connects/" + url.PathEscape(nonce) + return c.doJSON(ctx, http.MethodPost, realmID, subPath, sessionID, + ConnectResponseRequest{Addresses: addresses}, http.StatusNoContent, nil) +} + +func (c *Client) Events(ctx context.Context, realmID, sessionID string) (*EventStream, error) { + req, err := c.newRequest(ctx, http.MethodGet, realmID, "events", sessionID, nil) + if err != nil { + return nil, err + } + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + if resp.StatusCode != http.StatusOK { + defer resp.Body.Close() + return nil, decodeStatusError(resp) + } + return newEventStream(resp), nil +} + +type addressRequest struct { + Addresses []string `json:"addresses"` +} + +func (c *Client) doJSON(ctx context.Context, method, realmID, subPath, token string, in any, expectedStatus int, out any) error { + var body io.Reader + if in != nil { + bs, err := json.Marshal(in) + if err != nil { + return err + } + body = bytes.NewReader(bs) + } + req, err := c.newRequest(ctx, method, realmID, subPath, token, body) + if err != nil { + return err + } + if in != nil { + req.Header.Set("Content-Type", "application/json") + } + resp, err := c.httpClient.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode != expectedStatus { + return decodeStatusError(resp) + } + if out == nil || resp.Body == nil { + return nil + } + return json.NewDecoder(resp.Body).Decode(out) +} + +func (c *Client) newRequest(ctx context.Context, method, realmID, subPath, token string, body io.Reader) (*http.Request, error) { + u := &url.URL{ + Scheme: c.scheme, + Host: c.hostport, + Path: joinURLPath("v1", url.PathEscape(realmID), subPath), + } + req, err := http.NewRequestWithContext(ctx, method, u.String(), body) + if err != nil { + return nil, err + } + if token != "" { + req.Header.Set("Authorization", "Bearer "+token) + } + return req, nil +} + +func randHex(size int) (string, error) { + b := make([]byte, size) + if _, err := rand.Read(b); err != nil { + return "", err + } + return hex.EncodeToString(b), nil +} + +func joinURLPath(parts ...string) string { + var joined []string + for _, part := range parts { + part = strings.Trim(part, "/") + if part == "" { + continue + } + joined = append(joined, part) + } + return "/" + strings.Join(joined, "/") +} + +func decodeStatusError(resp *http.Response) error { + var errResp ErrorResponse + _ = json.NewDecoder(io.LimitReader(resp.Body, maxErrorBodySize)).Decode(&errResp) + return &StatusError{ + StatusCode: resp.StatusCode, + Response: errResp, + } +} + +type EventStream struct { + resp *http.Response + scanner *bufio.Scanner +} + +func newEventStream(resp *http.Response) *EventStream { + scanner := bufio.NewScanner(resp.Body) + scanner.Buffer(make([]byte, 1024), 1024*1024) + return &EventStream{ + resp: resp, + scanner: scanner, + } +} + +func (s *EventStream) Close() error { + return s.resp.Body.Close() +} + +func (s *EventStream) Next() (*PunchEvent, error) { + var eventName string + var data strings.Builder + for s.scanner.Scan() { + line := s.scanner.Text() + if line == "" { + if eventName == "" && data.Len() == 0 { + continue + } + if eventName != "punch" { + eventName = "" + data.Reset() + continue + } + var ev PunchEvent + if err := json.Unmarshal([]byte(data.String()), &ev); err != nil { + return nil, err + } + return &ev, nil + } + if strings.HasPrefix(line, ":") { + continue + } + field, value, ok := strings.Cut(line, ":") + if !ok { + continue + } + value = strings.TrimPrefix(value, " ") + switch field { + case "event": + eventName = value + case "data": + if data.Len() > 0 { + data.WriteByte('\n') + } + data.WriteString(value) + } + } + if err := s.scanner.Err(); err != nil { + return nil, err + } + return nil, io.EOF +} diff --git a/transport/internet/finalmask/realm/punch.go b/transport/internet/finalmask/realm/punch.go new file mode 100644 index 00000000..92174f6b --- /dev/null +++ b/transport/internet/finalmask/realm/punch.go @@ -0,0 +1,145 @@ +package realm + +import ( + "bytes" + "crypto/rand" + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "math/big" +) + +const ( + MaxPunchPadding = 1024 + + punchSaltLen = 8 + // Plain punch payload before obfs: + // 8-byte magic, 1-byte type, 16-byte nonce, then 0..1024 random padding bytes. + punchHeaderLen = 25 + punchMinWireLen = punchSaltLen + punchHeaderLen + punchMaxWireLen = punchMinWireLen + MaxPunchPadding +) + +var ( + ErrInvalidPunchPacket = errors.New("invalid punch packet") + + punchMagic = [8]byte{'H', 'Y', 'R', 'L', 'M', 'v', '1', 0} +) + +type PunchPacketType byte + +const ( + PunchPacketHello PunchPacketType = 0x01 + PunchPacketAck PunchPacketType = 0x02 +) + +type PunchPacket struct { + Type PunchPacketType + PaddingLength int +} + +func EncodePunchPacket(packetType PunchPacketType, meta PunchMetadata) ([]byte, error) { + if !validPunchPacketType(packetType) { + return nil, fmt.Errorf("%w: unknown packet type", ErrInvalidPunchPacket) + } + nonce, obfsKey, err := decodePunchMetadata(meta) + if err != nil { + return nil, err + } + paddingLength, err := randomPaddingLength() + if err != nil { + return nil, err + } + plain := make([]byte, punchHeaderLen+paddingLength) + copy(plain[:len(punchMagic)], punchMagic[:]) + plain[len(punchMagic)] = byte(packetType) + copy(plain[len(punchMagic)+1:punchHeaderLen], nonce) + if paddingLength > 0 { + if _, err := rand.Read(plain[punchHeaderLen:]); err != nil { + return nil, err + } + } + packet := make([]byte, punchSaltLen+len(plain)) + if _, err := rand.Read(packet[:punchSaltLen]); err != nil { + return nil, err + } + copy(packet[punchSaltLen:], plain) + xorPunchPacket(packet[punchSaltLen:], obfsKey, packet[:punchSaltLen]) + return packet, nil +} + +func DecodePunchPacket(packet []byte, meta PunchMetadata) (PunchPacket, error) { + if len(packet) < punchMinWireLen { + return PunchPacket{}, fmt.Errorf("%w: packet too short", ErrInvalidPunchPacket) + } + if len(packet) > punchMaxWireLen { + return PunchPacket{}, fmt.Errorf("%w: packet too long", ErrInvalidPunchPacket) + } + nonce, obfsKey, err := decodePunchMetadata(meta) + if err != nil { + return PunchPacket{}, err + } + salt := packet[:punchSaltLen] + plain := append([]byte(nil), packet[punchSaltLen:]...) + xorPunchPacket(plain, obfsKey, salt) + if !bytes.Equal(plain[:len(punchMagic)], punchMagic[:]) { + return PunchPacket{}, fmt.Errorf("%w: bad magic", ErrInvalidPunchPacket) + } + packetType := PunchPacketType(plain[len(punchMagic)]) + if !validPunchPacketType(packetType) { + return PunchPacket{}, fmt.Errorf("%w: unknown packet type", ErrInvalidPunchPacket) + } + if !bytes.Equal(plain[len(punchMagic)+1:punchHeaderLen], nonce) { + return PunchPacket{}, fmt.Errorf("%w: nonce mismatch", ErrInvalidPunchPacket) + } + return PunchPacket{ + Type: packetType, + PaddingLength: len(plain) - punchHeaderLen, + }, nil +} + +func decodePunchMetadata(meta PunchMetadata) (nonce, obfsKey []byte, err error) { + nonce, err = decodeHexSize("nonce", meta.Nonce, PunchNonceSize) + if err != nil { + return nil, nil, err + } + obfsKey, err = decodeHexSize("obfs", meta.Obfs, PunchObfsKeySize) + if err != nil { + return nil, nil, err + } + return nonce, obfsKey, nil +} + +func decodeHexSize(name, value string, size int) ([]byte, error) { + b, err := hex.DecodeString(value) + if err != nil { + return nil, fmt.Errorf("%w: invalid %s", ErrInvalidPunchPacket, name) + } + if len(b) != size { + return nil, fmt.Errorf("%w: invalid %s length", ErrInvalidPunchPacket, name) + } + return b, nil +} + +func randomPaddingLength() (int, error) { + n, err := rand.Int(rand.Reader, big.NewInt(MaxPunchPadding+1)) + if err != nil { + return 0, err + } + return int(n.Int64()), nil +} + +func xorPunchPacket(packet, obfsKey, salt []byte) { + h := sha256.New() + _, _ = h.Write(obfsKey) + _, _ = h.Write(salt) + mask := h.Sum(nil) + for i := range packet { + packet[i] ^= mask[i%len(mask)] + } +} + +func validPunchPacketType(packetType PunchPacketType) bool { + return packetType == PunchPacketHello || packetType == PunchPacketAck +} diff --git a/transport/internet/finalmask/realm/server.go b/transport/internet/finalmask/realm/server.go new file mode 100644 index 00000000..544bf49d --- /dev/null +++ b/transport/internet/finalmask/realm/server.go @@ -0,0 +1,401 @@ +package realm + +import ( + "context" + go_errors "errors" + "net" + "net/http" + "net/netip" + "slices" + "strings" + "sync" + "time" + + "github.com/pion/stun/v3" + "github.com/xtls/xray-core/common" + "github.com/xtls/xray-core/common/errors" +) + +const defaultEventBuffer = 16 +const defaultStunCacheTTL = time.Second * 10 +const defaultHeartbeatInterval = time.Second * 15 + +type PunchPacketEvent struct { + Addr netip.AddrPort + Packet PunchPacket +} + +type STUNPacketEvent struct { + Message *stun.Message + Addr netip.AddrPort +} + +type realmConnServer struct { + cleaned chan struct{} + ctx context.Context + cancel context.CancelFunc + net.PacketConn + + realmClient *Client + realmID string + stunServers []string + stunTimeout time.Duration + punchTimeout time.Duration + punchInterval time.Duration + + events map[PunchMetadata]chan PunchPacketEvent + stun chan STUNPacketEvent + mu sync.Mutex + + locals []netip.AddrPort + localsMu sync.Mutex + localsLast time.Time +} + +func NewConnServer(config *Config, raw net.PacketConn) (net.PacketConn, error) { + ctx, cancel := context.WithCancel(context.Background()) + + conn := &realmConnServer{ + cleaned: make(chan struct{}), + ctx: ctx, + cancel: cancel, + PacketConn: raw, + + realmClient: NewClient(config.Scheme, config.Host, config.Port, config.Token, config.TlsConfig), + realmID: config.ID, + stunServers: config.StunServers, + stunTimeout: defaultSTUNTimeout, + punchTimeout: defaultPunchTimeout, + punchInterval: defaultPunchInterval, + + events: make(map[PunchMetadata]chan PunchPacketEvent), + stun: make(chan STUNPacketEvent, defaultEventBuffer), + } + + go conn.run() + + return conn, nil +} + +func (c *realmConnServer) addSTUN(packet []byte) bool { + if !stun.IsMessage(packet) { + return false + } + msg, addr, err := parseSTUNBindingResponse(packet) + if err != nil { + return false + } + select { + case c.stun <- STUNPacketEvent{Message: msg, Addr: addr}: + default: + } + return true +} + +func (c *realmConnServer) addPunch(packet []byte, addr net.Addr) bool { + c.mu.Lock() + defer c.mu.Unlock() + for meta, ch := range c.events { + punchPacket, err := DecodePunchPacket(packet, meta) + if err != nil { + continue + } + select { + case ch <- PunchPacketEvent{ + Addr: addr.(*net.UDPAddr).AddrPort(), + Packet: punchPacket, + }: + default: + } + return true + } + return false +} + +func (c *realmConnServer) waitctx(ctx context.Context, t time.Duration) bool { + timer := time.NewTimer(t) + defer timer.Stop() + select { + case <-timer.C: + return false + case <-ctx.Done(): + return true + } +} + +func (c *realmConnServer) discover(servers []*net.UDPAddr) []netip.AddrPort { + var transactionIDs = make(map[[stun.TransactionIDSize]byte]struct{}, len(servers)) + for _, server := range servers { + msg := common.Must2(stun.Build(stun.TransactionID, stun.BindingRequest)) + transactionIDs[msg.TransactionID] = struct{}{} + _, _ = c.PacketConn.WriteTo(msg.Raw, server) + } + + var deadline = time.NewTimer(c.stunTimeout) + var results = make([]netip.AddrPort, 0, len(servers)) + for len(transactionIDs) > 0 { + select { + case <-deadline.C: + goto end + case ev := <-c.stun: + if _, ok := transactionIDs[ev.Message.TransactionID]; ok { + delete(transactionIDs, ev.Message.TransactionID) + results = append(results, ev.Addr) + } + } + } +end: + deadline.Stop() + slices.SortFunc(results, func(a, b netip.AddrPort) int { + return strings.Compare(a.String(), b.String()) + }) + + return results +} + +func (c *realmConnServer) getlocals(force bool) []netip.AddrPort { + c.localsMu.Lock() + if force || time.Since(c.localsLast) > defaultStunCacheTTL { + start := time.Now() + servers := resolveSTUNServers(c.PacketConn.LocalAddr().(*net.UDPAddr).IP, c.stunServers) + errors.LogDebug(context.Background(), "[realm] update stun servers ", servers, " with ", time.Since(start)) + if len(servers) > 0 { + start = time.Now() + locals := c.discover(servers) + errors.LogDebug(context.Background(), "[realm] update stun locals ", locals, " with ", time.Since(start)) + if len(locals) > 0 { + c.locals = locals + c.localsLast = time.Now() + } + } + } + locals := append([]netip.AddrPort(nil), c.locals...) + c.localsMu.Unlock() + return locals +} + +func (c *realmConnServer) punch(ctx context.Context, meta PunchMetadata, peers []netip.AddrPort) { + c.mu.Lock() + if _, ok := c.events[meta]; ok { + c.mu.Unlock() + return + } + ch := make(chan PunchPacketEvent, defaultEventBuffer) + c.events[meta] = ch + c.mu.Unlock() + + start := time.Now() + for _, peer := range peers { + packet := common.Must2(EncodePunchPacket(PunchPacketHello, meta)) + _, _ = c.PacketConn.WriteTo(packet, net.UDPAddrFromAddrPort(peer)) + } + deadline := time.NewTimer(c.punchTimeout) + ticker := time.NewTicker(c.punchInterval) + for { + select { + case <-ctx.Done(): + errors.LogDebug(context.Background(), "[realm] punch ", meta.Nonce, " FAIL > session end") + goto end + case <-deadline.C: + errors.LogDebug(context.Background(), "[realm] punch ", meta.Nonce, " FAIL > timeout") + goto end + case <-ticker.C: + for _, peer := range peers { + packet := common.Must2(EncodePunchPacket(PunchPacketHello, meta)) + _, _ = c.PacketConn.WriteTo(packet, net.UDPAddrFromAddrPort(peer)) + } + case event := <-ch: + if event.Packet.Type == PunchPacketHello { + packet := common.Must2(EncodePunchPacket(PunchPacketAck, meta)) + _, _ = c.PacketConn.WriteTo(packet, net.UDPAddrFromAddrPort(event.Addr)) + } + errors.LogDebug(context.Background(), "[realm] punch ", meta.Nonce, " SUCCESS ", event.Addr, " with ", time.Since(start)) + goto end + } + } +end: + deadline.Stop() + ticker.Stop() + + c.mu.Lock() + delete(c.events, meta) + close(ch) + c.mu.Unlock() +} + +func (c *realmConnServer) run() { + backoff := time.Second +retry: + resp, err := c.realmClient.Register(c.ctx, c.realmID, addrPortStrings(c.getlocals(false))) + if err != nil { + errors.LogErrorInner(context.Background(), err, "[realm] ", c.realmID, " register session err retry in ", backoff) + if c.waitctx(c.ctx, backoff) { + close(c.cleaned) + return + } + backoff *= 2 + if backoff > 30*time.Second { + backoff = 30 * time.Second + } + goto retry + } + backoff = time.Second + errors.LogDebug(context.Background(), "[realm] ", c.realmID, " sesssion ", resp.SessionID, " ", resp.TTL, " registered") + + ctx, cancel := context.WithCancel(context.Background()) + errCh := make(chan error, 2) + go c.heartbeatLoop(ctx, resp.SessionID, resp.TTL, errCh) + go c.eventsLoop(ctx, resp.SessionID, resp.TTL, errCh) + select { + case <-c.ctx.Done(): + case err = <-errCh: + } + cancel() + errors.LogDebugInner(context.Background(), err, "[realm] session ", resp.SessionID, " end") + + select { + case <-c.ctx.Done(): + _ = c.realmClient.Deregister(context.Background(), c.realmID, resp.SessionID) + errors.LogDebug(context.Background(), "[realm] ", c.realmID, " ", resp.SessionID, " deregistered") + close(c.cleaned) + return + default: + goto retry + } +} + +func (c *realmConnServer) heartbeatLoop(ctx context.Context, sid string, ttl int, errCh chan<- error) { + interval := defaultHeartbeatInterval + if ttl > 0 { + interval = time.Second * time.Duration(ttl) / 2 + } + + last := time.Now() + cur := c.getlocals(false) + ticker := time.NewTicker(interval) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + errCh <- nil + return + case <-ticker.C: + req := HeartbeatRequest{} + if new := c.getlocals(false); !slices.Equal(cur, new) { + cur = new + req.Addresses = addrPortStrings(cur) + } + start := time.Now() + resp, err := c.realmClient.Heartbeat(ctx, c.realmID, sid, req) + if err != nil { + var statusErr *StatusError + if go_errors.As(err, &statusErr) && (statusErr.StatusCode == http.StatusUnauthorized || statusErr.StatusCode == http.StatusNotFound) { + errCh <- errors.New("session invalid") + return + } + if time.Since(last) > time.Second*time.Duration(ttl) { + errCh <- errors.New("session lost") + return + } + continue + } + last = start + errors.LogDebug(context.Background(), "[realm] heartbeat ", resp.TTL, " with ", time.Since(start)) + if resp.TTL > 0 && resp.TTL != ttl { + ttl = resp.TTL + ticker.Reset(time.Second * time.Duration(ttl) / 2) + } + } + } +} + +func (c *realmConnServer) eventsLoop(ctx context.Context, sid string, ttl int, errCh chan<- error) { + backoff := time.Second + last := time.Now() + for { + start := time.Now() + stream, err := c.realmClient.Events(ctx, c.realmID, sid) + if err != nil { + var statusErr *StatusError + if go_errors.As(err, &statusErr) && (statusErr.StatusCode == http.StatusUnauthorized || statusErr.StatusCode == http.StatusNotFound) { + errCh <- errors.New("session invalid") + return + } + if time.Since(last) > time.Second*time.Duration(ttl) { + errCh <- errors.New("session lost") + return + } + errors.LogDebugInner(context.Background(), err, "[realm] ", sid, " open stream err retry in ", backoff) + if c.waitctx(ctx, backoff) { + errCh <- nil + return + } + backoff *= 2 + if backoff > 30*time.Second { + backoff = 30 * time.Second + } + continue + } + backoff = time.Second + last = start + errors.LogDebug(context.Background(), "[realm] open stream with ", time.Since(start)) + for { + ev, err := stream.Next() + if err != nil { + _ = stream.Close() + break + } + last = time.Now() + go c.punchEvent(ctx, sid, ev) + } + } +} + +func (c *realmConnServer) punchEvent(ctx context.Context, sid string, ev *PunchEvent) { + errors.LogDebug(context.Background(), "[realm] start punch event ", ev.Nonce, " ", ev.Addresses) + + locals := c.getlocals(false) + + peers, _ := parseAddrPorts(ev.Addresses) + errors.LogDebug(context.Background(), "[realm] ", ev.Nonce, " update peers ", peers) + filteredPeers, seen := candidatePunchAddrs(locals, peers) + errors.LogDebug(context.Background(), "[realm] ", ev.Nonce, " filtered peers ", filteredPeers) + expandedPeers := expandSymmetricNATCandidates(filteredPeers, seen) + errors.LogDebug(context.Background(), "[realm] ", ev.Nonce, " expanded peers ", expandedPeers) + + if len(expandedPeers) == 0 { + errors.LogDebug(context.Background(), "[realm] punch ", ev.Nonce, " FAIL > empty peers") + return + } + + start := time.Now() + err := c.realmClient.ConnectResponse(ctx, c.realmID, sid, ev.Nonce, addrPortStrings(locals)) + if err != nil { + errors.LogDebugInner(context.Background(), err, "[realm] ", ev.Nonce, " connect response err") + } + errors.LogDebug(context.Background(), "[realm] ", ev.Nonce, " connect response ", locals, " with ", time.Since(start)) + + c.punch(ctx, ev.PunchMetadata, expandedPeers) +} + +func (c *realmConnServer) ReadFrom(p []byte) (int, net.Addr, error) { + for { + n, addr, err := c.PacketConn.ReadFrom(p) + if err != nil { + return n, addr, err + } + if c.addSTUN(p[:n]) { + continue + } + if c.addPunch(p[:n], addr) { + continue + } + return n, addr, nil + } +} + +func (c *realmConnServer) Close() error { + c.cancel() + <-c.cleaned + return c.PacketConn.Close() +} diff --git a/transport/internet/finalmask/realm/stun.go b/transport/internet/finalmask/realm/stun.go new file mode 100644 index 00000000..5ff59bb3 --- /dev/null +++ b/transport/internet/finalmask/realm/stun.go @@ -0,0 +1,224 @@ +package realm + +import ( + "context" + "errors" + "net" + "net/netip" + "slices" + "strconv" + "strings" + "time" + + "github.com/pion/stun/v3" +) + +const ( + defaultSTUNTimeout = 4 * time.Second + defaultPunchTimeout = 10 * time.Second + defaultPunchInterval = 100 * time.Millisecond + + symmetricNATPortGap = 4 + symmetricNATExtraPorts = 4 + symmetricNATMaxPortsPerHost = 32 +) + +func resolveSTUNServers(local net.IP, servers []string) []*net.UDPAddr { + var network string + if local.IsUnspecified() { + network = "ip" + } else { + if local.To4() != nil { + network = "ip4" + } else { + network = "ip6" + } + } + + var seen = make(map[string]struct{}) + var addrs = make([]*net.UDPAddr, 0, len(servers)) + for _, server := range servers { + h, p, err := net.SplitHostPort(server) + if err != nil { + continue + } + port, err := strconv.Atoi(p) + if err != nil { + continue + } + ips, err := net.DefaultResolver.LookupIP(context.Background(), network, h) + if err != nil { + continue + } + for _, ip := range ips { + if _, ok := seen[net.JoinHostPort(ip.String(), p)]; !ok { + seen[net.JoinHostPort(ip.String(), p)] = struct{}{} + addrs = append(addrs, &net.UDPAddr{IP: ip, Port: port}) + break + } + } + } + + return addrs +} + +func parseSTUNBindingResponse(packet []byte) (*stun.Message, netip.AddrPort, error) { + msg := stun.New() + if err := stun.Decode(packet, msg); err != nil { + return nil, netip.AddrPort{}, err + } + if msg.Type != stun.BindingSuccess { + return nil, netip.AddrPort{}, errors.New("not a STUN binding success response") + } + + var xorMapped stun.XORMappedAddress + if err := xorMapped.GetFrom(msg); err == nil { + addr, err := netIPPortToAddrPort(xorMapped.IP, xorMapped.Port) + return msg, addr, err + } + + var mapped stun.MappedAddress + if err := mapped.GetFrom(msg); err == nil { + addr, err := netIPPortToAddrPort(mapped.IP, mapped.Port) + return msg, addr, err + } + + return nil, netip.AddrPort{}, errors.New("STUN mapped address not found") +} + +func netIPPortToAddrPort(ip net.IP, port int) (netip.AddrPort, error) { + if port <= 0 || port > 65535 { + return netip.AddrPort{}, errors.New("invalid STUN mapped port") + } + if ip4 := ip.To4(); ip4 != nil { + var addr [4]byte + copy(addr[:], ip4) + return netip.AddrPortFrom(netip.AddrFrom4(addr), uint16(port)), nil + } + ip16 := ip.To16() + if ip16 == nil { + return netip.AddrPort{}, errors.New("invalid STUN mapped IP") + } + var addr [16]byte + copy(addr[:], ip16) + return netip.AddrPortFrom(netip.AddrFrom16(addr), uint16(port)), nil +} + +func candidatePunchAddrs(locals, peers []netip.AddrPort) ([]netip.AddrPort, map[netip.AddrPort]struct{}) { + var allow4, allow6 bool + for _, local := range locals { + if local.Addr().Is4() { + allow4 = true + } else { + allow6 = true + } + if allow4 && allow6 { + break + } + } + var seen = make(map[netip.AddrPort]struct{}, len(peers)) + var candidates = make([]netip.AddrPort, 0, len(peers)) + for _, peer := range peers { + if _, ok := seen[peer]; ok { + continue + } + if peer.IsValid() { + if peer.Addr().Is4() { + if allow4 { + seen[peer] = struct{}{} + candidates = append(candidates, peer) + } + } else { + if allow6 { + seen[peer] = struct{}{} + candidates = append(candidates, peer) + } + } + } + } + return candidates, seen +} + +func expandSymmetricNATCandidates(candidates []netip.AddrPort, seen map[netip.AddrPort]struct{}) []netip.AddrPort { + portsByIP := make(map[netip.Addr][]uint16) + for _, addr := range candidates { + if addr.Addr().Is4() { + portsByIP[addr.Addr()] = append(portsByIP[addr.Addr()], addr.Port()) + } + } + for ip, ports := range portsByIP { + ports = uniqueSortedPorts(ports) + if !predictablePortGroup(ports) { + continue + } + start := int(ports[0]) + end := int(ports[len(ports)-1]) + symmetricNATExtraPorts + if end > 65535 { + end = 65535 + } + added := 0 + for port := start; port <= end && added < symmetricNATMaxPortsPerHost; port++ { + addr := netip.AddrPortFrom(ip, uint16(port)) + if _, ok := seen[addr]; ok { + continue + } + seen[addr] = struct{}{} + candidates = append(candidates, addr) + added++ + } + } + sortAddrPorts(candidates) + return candidates +} + +func uniqueSortedPorts(ports []uint16) []uint16 { + slices.Sort(ports) + out := ports[:0] + var last uint16 + for i, port := range ports { + if i > 0 && port == last { + continue + } + out = append(out, port) + last = port + } + return out +} + +func predictablePortGroup(ports []uint16) bool { + if len(ports) < 2 { + return false + } + for i := 1; i < len(ports); i++ { + if ports[i]-ports[i-1] > symmetricNATPortGap { + return false + } + } + return true +} + +func sortAddrPorts(addrs []netip.AddrPort) { + slices.SortFunc(addrs, func(a, b netip.AddrPort) int { + return strings.Compare(a.String(), b.String()) + }) +} + +func addrPortStrings(addrs []netip.AddrPort) []string { + out := make([]string, 0, len(addrs)) + for _, addr := range addrs { + out = append(out, addr.String()) + } + return out +} + +func parseAddrPorts(addrs []string) ([]netip.AddrPort, error) { + out := make([]netip.AddrPort, 0, len(addrs)) + for _, s := range addrs { + addr, err := netip.ParseAddrPort(s) + if err != nil { + return nil, err + } + out = append(out, addr) + } + return out, nil +} diff --git a/transport/internet/hysteria/congestion/bbr/bandwidth_sampler.go b/transport/internet/hysteria/congestion/bbr/bandwidth_sampler.go index 35b63c06..2bd66d09 100644 --- a/transport/internet/hysteria/congestion/bbr/bandwidth_sampler.go +++ b/transport/internet/hysteria/congestion/bbr/bandwidth_sampler.go @@ -801,7 +801,8 @@ func (b *bandwidthSampler) onPacketAcknowledged(ackTime monotime.Time, packetNum if sentPacketPointer.sentTime.After(sentPacketPointer.lastAckedPacketSentTime) { sendRate = BandwidthFromDelta( sentPacketPointer.sendTimeState.totalBytesSent-sentPacketPointer.totalBytesSentAtLastAckedPacket, - sentPacketPointer.sentTime.Sub(sentPacketPointer.lastAckedPacketSentTime)) + sentPacketPointer.sentTime.Sub(sentPacketPointer.lastAckedPacketSentTime), + ) } var a0 ackPoint @@ -848,7 +849,8 @@ func (b *bandwidthSampler) onAckEventEnd( b.lastSentPacket, b.lastAckedPacket, b.lastAckedPacketAckTime, - newlyAckedBytes) + newlyAckedBytes, + ) // If |extra_acked| is zero, i.e. this ack event marks the start of a new ack // aggregation epoch, save LessRecentPoint, which is the last ack point of the // previous epoch, as a A0 candidate. diff --git a/transport/internet/hysteria/congestion/bbr/bbr_sender.go b/transport/internet/hysteria/congestion/bbr/bbr_sender.go index bcbf8133..26e7bb7f 100644 --- a/transport/internet/hysteria/congestion/bbr/bbr_sender.go +++ b/transport/internet/hysteria/congestion/bbr/bbr_sender.go @@ -640,7 +640,8 @@ func (b *bbrSender) OnCongestionEventEx(priorInFlight congestion.ByteCount, even func (b *bbrSender) PacingRate() Bandwidth { if b.pacingRate == 0 { return Bandwidth(b.highGain * float64( - BandwidthFromDelta(b.initialCongestionWindow, b.getMinRtt()))) + BandwidthFromDelta(b.initialCongestionWindow, b.getMinRtt()), + )) } return b.pacingRate