mirror of
https://github.com/XTLS/Xray-core.git
synced 2026-07-03 02:08:45 +00:00
Finalmask: Add Realm (UDP hole punching in Hysteria v2.9.1) (#6137)
https://github.com/XTLS/Xray-core/pull/5657#issuecomment-4446406536 https://github.com/XTLS/Xray-core/pull/6137#issuecomment-4469822775 Example: https://github.com/XTLS/Xray-core/pull/6137#issue-4454013510
This commit is contained in:
@@ -0,0 +1,171 @@
|
||||
package realm
|
||||
|
||||
import (
|
||||
"context"
|
||||
goerrors "errors"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pion/stun/v3"
|
||||
"github.com/xtls/xray-core/common"
|
||||
"github.com/xtls/xray-core/common/errors"
|
||||
)
|
||||
|
||||
type realmConnClient struct {
|
||||
net.PacketConn
|
||||
peer *net.UDPAddr
|
||||
|
||||
realmClient *Client
|
||||
realmID string
|
||||
stunServers []string
|
||||
stunTimeout time.Duration
|
||||
punchTimeout time.Duration
|
||||
punchInterval time.Duration
|
||||
}
|
||||
|
||||
func NewConnClient(config *Config, raw net.PacketConn) (net.PacketConn, error) {
|
||||
conn := &realmConnClient{
|
||||
PacketConn: raw,
|
||||
|
||||
realmClient: NewClient(config.Scheme, config.Host, config.Port, config.Token, config.TlsConfig),
|
||||
realmID: config.ID,
|
||||
stunServers: config.StunServers,
|
||||
stunTimeout: defaultSTUNTimeout,
|
||||
punchTimeout: defaultPunchTimeout,
|
||||
punchInterval: defaultPunchInterval,
|
||||
}
|
||||
|
||||
return conn.getpeer()
|
||||
}
|
||||
|
||||
func (c *realmConnClient) getpeer() (net.PacketConn, error) {
|
||||
start := time.Now()
|
||||
servers := resolveSTUNServers(c.PacketConn.LocalAddr().(*net.UDPAddr).IP, c.stunServers)
|
||||
errors.LogDebug(context.Background(), "[realm] update stun servers ", servers, " with ", time.Since(start))
|
||||
if len(servers) == 0 {
|
||||
return nil, errors.New("empty locals")
|
||||
}
|
||||
|
||||
start = time.Now()
|
||||
locals := c.discover(servers)
|
||||
errors.LogDebug(context.Background(), "[realm] update stun locals ", locals, " with ", time.Since(start))
|
||||
if len(locals) == 0 {
|
||||
return nil, errors.New("empty locals")
|
||||
}
|
||||
|
||||
meta := common.Must2(NewPunchMetadata())
|
||||
|
||||
start = time.Now()
|
||||
resp, err := c.realmClient.Connect(context.Background(), c.realmID, ConnectRequest{
|
||||
Addresses: addrPortStrings(locals),
|
||||
PunchMetadata: meta,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
errors.LogDebug(context.Background(), "[realm] ", c.realmID, " ", meta.Nonce, " connect ", resp.Addresses, " with ", time.Since(start))
|
||||
|
||||
peers, _ := parseAddrPorts(resp.Addresses)
|
||||
errors.LogDebug(context.Background(), "[realm] update peers ", peers)
|
||||
filteredPeers, seen := candidatePunchAddrs(locals, peers)
|
||||
errors.LogDebug(context.Background(), "[realm] filtered peers ", filteredPeers)
|
||||
expandedPeers := expandSymmetricNATCandidates(filteredPeers, seen)
|
||||
errors.LogDebug(context.Background(), "[realm] expanded peers ", expandedPeers)
|
||||
|
||||
if len(expandedPeers) == 0 {
|
||||
return nil, errors.New("empty peers")
|
||||
}
|
||||
|
||||
start = time.Now()
|
||||
peer, err := c.punch(meta, peers)
|
||||
if err != nil {
|
||||
return nil, errors.New("punch fail").Base(err)
|
||||
}
|
||||
errors.LogDebug(context.Background(), "[realm] punch peer ", peer, " with ", time.Since(start))
|
||||
|
||||
c.peer = peer
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *realmConnClient) discover(servers []*net.UDPAddr) []netip.AddrPort {
|
||||
var transactionIDs = make(map[[stun.TransactionIDSize]byte]struct{}, len(servers))
|
||||
for _, server := range servers {
|
||||
msg := common.Must2(stun.Build(stun.TransactionID, stun.BindingRequest))
|
||||
transactionIDs[msg.TransactionID] = struct{}{}
|
||||
_, _ = c.PacketConn.WriteTo(msg.Raw, server)
|
||||
}
|
||||
|
||||
var buf = make([]byte, 1500)
|
||||
var results = make([]netip.AddrPort, 0, len(servers))
|
||||
c.PacketConn.SetReadDeadline(time.Now().Add(defaultSTUNTimeout))
|
||||
for len(transactionIDs) > 0 {
|
||||
n, _, err := c.PacketConn.ReadFrom(buf)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
msg, addrPort, err := parseSTUNBindingResponse(buf[:n])
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if _, ok := transactionIDs[msg.TransactionID]; ok {
|
||||
delete(transactionIDs, msg.TransactionID)
|
||||
results = append(results, addrPort)
|
||||
}
|
||||
}
|
||||
c.PacketConn.SetReadDeadline(time.Time{})
|
||||
slices.SortFunc(results, func(a, b netip.AddrPort) int {
|
||||
return strings.Compare(a.String(), b.String())
|
||||
})
|
||||
|
||||
return results
|
||||
}
|
||||
|
||||
func (c *realmConnClient) punch(meta PunchMetadata, peers []netip.AddrPort) (*net.UDPAddr, error) {
|
||||
defer c.PacketConn.SetReadDeadline(time.Time{})
|
||||
nextSend := time.Now()
|
||||
deadline := nextSend.Add(c.punchTimeout)
|
||||
buf := make([]byte, punchMaxWireLen)
|
||||
for {
|
||||
now := time.Now()
|
||||
if now.After(deadline) {
|
||||
return nil, errors.New("timeout")
|
||||
}
|
||||
if now.After(nextSend) {
|
||||
for _, peer := range peers {
|
||||
packet := common.Must2(EncodePunchPacket(PunchPacketHello, meta))
|
||||
_, _ = c.PacketConn.WriteTo(packet, net.UDPAddrFromAddrPort(peer))
|
||||
}
|
||||
nextSend = now.Add(c.punchInterval)
|
||||
}
|
||||
|
||||
if nextSend.After(deadline) {
|
||||
c.PacketConn.SetReadDeadline(deadline)
|
||||
} else {
|
||||
c.PacketConn.SetReadDeadline(nextSend)
|
||||
}
|
||||
n, addr, err := c.PacketConn.ReadFrom(buf)
|
||||
if err != nil {
|
||||
var netErr net.Error
|
||||
if goerrors.As(err, &netErr) && netErr.Timeout() {
|
||||
continue
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
packet, err := DecodePunchPacket(buf[:n], meta)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if packet.Type == PunchPacketHello {
|
||||
packet := common.Must2(EncodePunchPacket(PunchPacketAck, meta))
|
||||
_, _ = c.PacketConn.WriteTo(packet, addr)
|
||||
}
|
||||
return addr.(*net.UDPAddr), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *realmConnClient) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||
return c.PacketConn.WriteTo(p, c.peer)
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
package realm
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/xtls/xray-core/common/errors"
|
||||
"github.com/xtls/xray-core/transport/internet"
|
||||
"github.com/xtls/xray-core/transport/internet/hysteria/udphop"
|
||||
)
|
||||
|
||||
func (c *Config) UDP() {}
|
||||
|
||||
func (c *Config) WrapPacketConnClient(raw net.PacketConn, level int, levelCount int) (net.PacketConn, error) {
|
||||
_, ok1 := raw.(*internet.FakePacketConn)
|
||||
_, ok2 := raw.(*udphop.UdpHopPacketConn)
|
||||
if level != 0 || ok1 || ok2 {
|
||||
return nil, errors.New("realm requires being at the outermost level")
|
||||
}
|
||||
return NewConnClient(c, raw)
|
||||
}
|
||||
|
||||
func (c *Config) WrapPacketConnServer(raw net.PacketConn, level int, levelCount int) (net.PacketConn, error) {
|
||||
if level != 0 {
|
||||
return nil, errors.New("realm requires being at the outermost level")
|
||||
}
|
||||
return NewConnServer(c, raw)
|
||||
}
|
||||
@@ -0,0 +1,181 @@
|
||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.36.11
|
||||
// protoc v6.33.5
|
||||
// source: transport/internet/finalmask/realm/config.proto
|
||||
|
||||
package realm
|
||||
|
||||
import (
|
||||
tls "github.com/xtls/xray-core/transport/internet/tls"
|
||||
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
|
||||
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
|
||||
reflect "reflect"
|
||||
sync "sync"
|
||||
unsafe "unsafe"
|
||||
)
|
||||
|
||||
const (
|
||||
// Verify that this generated code is sufficiently up-to-date.
|
||||
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
|
||||
// Verify that runtime/protoimpl is sufficiently up-to-date.
|
||||
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
Scheme string `protobuf:"bytes,1,opt,name=scheme,proto3" json:"scheme,omitempty"`
|
||||
Host string `protobuf:"bytes,2,opt,name=host,proto3" json:"host,omitempty"`
|
||||
Port string `protobuf:"bytes,3,opt,name=port,proto3" json:"port,omitempty"`
|
||||
Token string `protobuf:"bytes,4,opt,name=token,proto3" json:"token,omitempty"`
|
||||
ID string `protobuf:"bytes,5,opt,name=ID,proto3" json:"ID,omitempty"`
|
||||
StunServers []string `protobuf:"bytes,6,rep,name=stun_servers,json=stunServers,proto3" json:"stun_servers,omitempty"`
|
||||
TlsConfig *tls.Config `protobuf:"bytes,7,opt,name=tls_config,json=tlsConfig,proto3" json:"tls_config,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *Config) Reset() {
|
||||
*x = Config{}
|
||||
mi := &file_transport_internet_finalmask_realm_config_proto_msgTypes[0]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *Config) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*Config) ProtoMessage() {}
|
||||
|
||||
func (x *Config) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_transport_internet_finalmask_realm_config_proto_msgTypes[0]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use Config.ProtoReflect.Descriptor instead.
|
||||
func (*Config) Descriptor() ([]byte, []int) {
|
||||
return file_transport_internet_finalmask_realm_config_proto_rawDescGZIP(), []int{0}
|
||||
}
|
||||
|
||||
func (x *Config) GetScheme() string {
|
||||
if x != nil {
|
||||
return x.Scheme
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *Config) GetHost() string {
|
||||
if x != nil {
|
||||
return x.Host
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *Config) GetPort() string {
|
||||
if x != nil {
|
||||
return x.Port
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *Config) GetToken() string {
|
||||
if x != nil {
|
||||
return x.Token
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *Config) GetID() string {
|
||||
if x != nil {
|
||||
return x.ID
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *Config) GetStunServers() []string {
|
||||
if x != nil {
|
||||
return x.StunServers
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *Config) GetTlsConfig() *tls.Config {
|
||||
if x != nil {
|
||||
return x.TlsConfig
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var File_transport_internet_finalmask_realm_config_proto protoreflect.FileDescriptor
|
||||
|
||||
const file_transport_internet_finalmask_realm_config_proto_rawDesc = "" +
|
||||
"\n" +
|
||||
"/transport/internet/finalmask/realm/config.proto\x12'xray.transport.internet.finalmask.realm\x1a#transport/internet/tls/config.proto\"\xd5\x01\n" +
|
||||
"\x06Config\x12\x16\n" +
|
||||
"\x06scheme\x18\x01 \x01(\tR\x06scheme\x12\x12\n" +
|
||||
"\x04host\x18\x02 \x01(\tR\x04host\x12\x12\n" +
|
||||
"\x04port\x18\x03 \x01(\tR\x04port\x12\x14\n" +
|
||||
"\x05token\x18\x04 \x01(\tR\x05token\x12\x0e\n" +
|
||||
"\x02ID\x18\x05 \x01(\tR\x02ID\x12!\n" +
|
||||
"\fstun_servers\x18\x06 \x03(\tR\vstunServers\x12B\n" +
|
||||
"\n" +
|
||||
"tls_config\x18\a \x01(\v2#.xray.transport.internet.tls.ConfigR\ttlsConfigB\x97\x01\n" +
|
||||
"+com.xray.transport.internet.finalmask.realmP\x01Z<github.com/xtls/xray-core/transport/internet/finalmask/realm\xaa\x02'Xray.Transport.Internet.Finalmask.Realmb\x06proto3"
|
||||
|
||||
var (
|
||||
file_transport_internet_finalmask_realm_config_proto_rawDescOnce sync.Once
|
||||
file_transport_internet_finalmask_realm_config_proto_rawDescData []byte
|
||||
)
|
||||
|
||||
func file_transport_internet_finalmask_realm_config_proto_rawDescGZIP() []byte {
|
||||
file_transport_internet_finalmask_realm_config_proto_rawDescOnce.Do(func() {
|
||||
file_transport_internet_finalmask_realm_config_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_transport_internet_finalmask_realm_config_proto_rawDesc), len(file_transport_internet_finalmask_realm_config_proto_rawDesc)))
|
||||
})
|
||||
return file_transport_internet_finalmask_realm_config_proto_rawDescData
|
||||
}
|
||||
|
||||
var file_transport_internet_finalmask_realm_config_proto_msgTypes = make([]protoimpl.MessageInfo, 1)
|
||||
var file_transport_internet_finalmask_realm_config_proto_goTypes = []any{
|
||||
(*Config)(nil), // 0: xray.transport.internet.finalmask.realm.Config
|
||||
(*tls.Config)(nil), // 1: xray.transport.internet.tls.Config
|
||||
}
|
||||
var file_transport_internet_finalmask_realm_config_proto_depIdxs = []int32{
|
||||
1, // 0: xray.transport.internet.finalmask.realm.Config.tls_config:type_name -> xray.transport.internet.tls.Config
|
||||
1, // [1:1] is the sub-list for method output_type
|
||||
1, // [1:1] is the sub-list for method input_type
|
||||
1, // [1:1] is the sub-list for extension type_name
|
||||
1, // [1:1] is the sub-list for extension extendee
|
||||
0, // [0:1] is the sub-list for field type_name
|
||||
}
|
||||
|
||||
func init() { file_transport_internet_finalmask_realm_config_proto_init() }
|
||||
func file_transport_internet_finalmask_realm_config_proto_init() {
|
||||
if File_transport_internet_finalmask_realm_config_proto != nil {
|
||||
return
|
||||
}
|
||||
type x struct{}
|
||||
out := protoimpl.TypeBuilder{
|
||||
File: protoimpl.DescBuilder{
|
||||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||
RawDescriptor: unsafe.Slice(unsafe.StringData(file_transport_internet_finalmask_realm_config_proto_rawDesc), len(file_transport_internet_finalmask_realm_config_proto_rawDesc)),
|
||||
NumEnums: 0,
|
||||
NumMessages: 1,
|
||||
NumExtensions: 0,
|
||||
NumServices: 0,
|
||||
},
|
||||
GoTypes: file_transport_internet_finalmask_realm_config_proto_goTypes,
|
||||
DependencyIndexes: file_transport_internet_finalmask_realm_config_proto_depIdxs,
|
||||
MessageInfos: file_transport_internet_finalmask_realm_config_proto_msgTypes,
|
||||
}.Build()
|
||||
File_transport_internet_finalmask_realm_config_proto = out.File
|
||||
file_transport_internet_finalmask_realm_config_proto_goTypes = nil
|
||||
file_transport_internet_finalmask_realm_config_proto_depIdxs = nil
|
||||
}
|
||||
@@ -0,0 +1,19 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package xray.transport.internet.finalmask.realm;
|
||||
option csharp_namespace = "Xray.Transport.Internet.Finalmask.Realm";
|
||||
option go_package = "github.com/xtls/xray-core/transport/internet/finalmask/realm";
|
||||
option java_package = "com.xray.transport.internet.finalmask.realm";
|
||||
option java_multiple_files = true;
|
||||
|
||||
import "transport/internet/tls/config.proto";
|
||||
|
||||
message Config {
|
||||
string scheme = 1;
|
||||
string host = 2;
|
||||
string port = 3;
|
||||
string token = 4;
|
||||
string ID = 5;
|
||||
repeated string stun_servers = 6;
|
||||
xray.transport.internet.tls.Config tls_config = 7;
|
||||
}
|
||||
@@ -0,0 +1,307 @@
|
||||
package realm
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/xtls/xray-core/transport/internet/tls"
|
||||
)
|
||||
|
||||
const maxErrorBodySize = 64 * 1024
|
||||
|
||||
const (
|
||||
PunchNonceSize = 16
|
||||
PunchObfsKeySize = 32
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
scheme string
|
||||
hostport string
|
||||
token string
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
type RegisterResponse struct {
|
||||
SessionID string `json:"session_id"`
|
||||
TTL int `json:"ttl"`
|
||||
}
|
||||
|
||||
type HeartbeatResponse struct {
|
||||
TTL int `json:"ttl"`
|
||||
}
|
||||
|
||||
type HeartbeatRequest struct {
|
||||
Addresses []string `json:"addresses,omitempty"`
|
||||
}
|
||||
|
||||
type PunchMetadata struct {
|
||||
Nonce string `json:"nonce"`
|
||||
Obfs string `json:"obfs"`
|
||||
}
|
||||
|
||||
type ConnectRequest struct {
|
||||
Addresses []string `json:"addresses"`
|
||||
PunchMetadata
|
||||
}
|
||||
|
||||
type ConnectResponse struct {
|
||||
Addresses []string `json:"addresses"`
|
||||
PunchMetadata
|
||||
}
|
||||
|
||||
type PunchEvent struct {
|
||||
Addresses []string `json:"addresses"`
|
||||
PunchMetadata
|
||||
}
|
||||
|
||||
type ErrorResponse struct {
|
||||
Error string `json:"error"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type StatusError struct {
|
||||
StatusCode int
|
||||
Response ErrorResponse
|
||||
}
|
||||
|
||||
func (e *StatusError) Error() string {
|
||||
if e.Response.Error != "" || e.Response.Message != "" {
|
||||
return fmt.Sprintf("realm server returned %d: %s: %s", e.StatusCode, e.Response.Error, e.Response.Message)
|
||||
}
|
||||
return fmt.Sprintf("realm server returned %d", e.StatusCode)
|
||||
}
|
||||
|
||||
func NewClient(scheme, host, port, token string, tlsConfig *tls.Config) *Client {
|
||||
client := http.DefaultClient
|
||||
if tlsConfig != nil {
|
||||
tr := http.DefaultTransport.(*http.Transport).Clone()
|
||||
tr.TLSClientConfig = tlsConfig.GetTLSConfig()
|
||||
client = &http.Client{Transport: tr}
|
||||
}
|
||||
return &Client{
|
||||
scheme: scheme,
|
||||
hostport: net.JoinHostPort(host, port),
|
||||
token: token,
|
||||
httpClient: client,
|
||||
}
|
||||
}
|
||||
|
||||
func NewPunchMetadata() (PunchMetadata, error) {
|
||||
nonce, err := randHex(PunchNonceSize)
|
||||
if err != nil {
|
||||
return PunchMetadata{}, err
|
||||
}
|
||||
obfs, err := randHex(PunchObfsKeySize)
|
||||
if err != nil {
|
||||
return PunchMetadata{}, err
|
||||
}
|
||||
return PunchMetadata{
|
||||
Nonce: nonce,
|
||||
Obfs: obfs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Client) Register(ctx context.Context, realmID string, addresses []string) (*RegisterResponse, error) {
|
||||
var resp RegisterResponse
|
||||
if err := c.doJSON(ctx, http.MethodPost, realmID, "", c.token, addressRequest{Addresses: addresses}, http.StatusOK, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
func (c *Client) Deregister(ctx context.Context, realmID, sessionID string) error {
|
||||
return c.doJSON(ctx, http.MethodDelete, realmID, "", sessionID, nil, http.StatusNoContent, nil)
|
||||
}
|
||||
|
||||
func (c *Client) Heartbeat(ctx context.Context, realmID, sessionID string, req HeartbeatRequest) (*HeartbeatResponse, error) {
|
||||
var resp HeartbeatResponse
|
||||
if err := c.doJSON(ctx, http.MethodPost, realmID, "heartbeat", sessionID, req, http.StatusOK, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
func (c *Client) Connect(ctx context.Context, realmID string, req ConnectRequest) (*ConnectResponse, error) {
|
||||
var resp ConnectResponse
|
||||
if err := c.doJSON(ctx, http.MethodPost, realmID, "connect", c.token, req, http.StatusOK, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
type ConnectResponseRequest struct {
|
||||
Addresses []string `json:"addresses"`
|
||||
}
|
||||
|
||||
func (c *Client) ConnectResponse(ctx context.Context, realmID, sessionID, nonce string, addresses []string) error {
|
||||
subPath := "connects/" + url.PathEscape(nonce)
|
||||
return c.doJSON(ctx, http.MethodPost, realmID, subPath, sessionID,
|
||||
ConnectResponseRequest{Addresses: addresses}, http.StatusNoContent, nil)
|
||||
}
|
||||
|
||||
func (c *Client) Events(ctx context.Context, realmID, sessionID string) (*EventStream, error) {
|
||||
req, err := c.newRequest(ctx, http.MethodGet, realmID, "events", sessionID, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
defer resp.Body.Close()
|
||||
return nil, decodeStatusError(resp)
|
||||
}
|
||||
return newEventStream(resp), nil
|
||||
}
|
||||
|
||||
type addressRequest struct {
|
||||
Addresses []string `json:"addresses"`
|
||||
}
|
||||
|
||||
func (c *Client) doJSON(ctx context.Context, method, realmID, subPath, token string, in any, expectedStatus int, out any) error {
|
||||
var body io.Reader
|
||||
if in != nil {
|
||||
bs, err := json.Marshal(in)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
body = bytes.NewReader(bs)
|
||||
}
|
||||
req, err := c.newRequest(ctx, method, realmID, subPath, token, body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if in != nil {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != expectedStatus {
|
||||
return decodeStatusError(resp)
|
||||
}
|
||||
if out == nil || resp.Body == nil {
|
||||
return nil
|
||||
}
|
||||
return json.NewDecoder(resp.Body).Decode(out)
|
||||
}
|
||||
|
||||
func (c *Client) newRequest(ctx context.Context, method, realmID, subPath, token string, body io.Reader) (*http.Request, error) {
|
||||
u := &url.URL{
|
||||
Scheme: c.scheme,
|
||||
Host: c.hostport,
|
||||
Path: joinURLPath("v1", url.PathEscape(realmID), subPath),
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, method, u.String(), body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if token != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func randHex(size int) (string, error) {
|
||||
b := make([]byte, size)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
func joinURLPath(parts ...string) string {
|
||||
var joined []string
|
||||
for _, part := range parts {
|
||||
part = strings.Trim(part, "/")
|
||||
if part == "" {
|
||||
continue
|
||||
}
|
||||
joined = append(joined, part)
|
||||
}
|
||||
return "/" + strings.Join(joined, "/")
|
||||
}
|
||||
|
||||
func decodeStatusError(resp *http.Response) error {
|
||||
var errResp ErrorResponse
|
||||
_ = json.NewDecoder(io.LimitReader(resp.Body, maxErrorBodySize)).Decode(&errResp)
|
||||
return &StatusError{
|
||||
StatusCode: resp.StatusCode,
|
||||
Response: errResp,
|
||||
}
|
||||
}
|
||||
|
||||
type EventStream struct {
|
||||
resp *http.Response
|
||||
scanner *bufio.Scanner
|
||||
}
|
||||
|
||||
func newEventStream(resp *http.Response) *EventStream {
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Buffer(make([]byte, 1024), 1024*1024)
|
||||
return &EventStream{
|
||||
resp: resp,
|
||||
scanner: scanner,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *EventStream) Close() error {
|
||||
return s.resp.Body.Close()
|
||||
}
|
||||
|
||||
func (s *EventStream) Next() (*PunchEvent, error) {
|
||||
var eventName string
|
||||
var data strings.Builder
|
||||
for s.scanner.Scan() {
|
||||
line := s.scanner.Text()
|
||||
if line == "" {
|
||||
if eventName == "" && data.Len() == 0 {
|
||||
continue
|
||||
}
|
||||
if eventName != "punch" {
|
||||
eventName = ""
|
||||
data.Reset()
|
||||
continue
|
||||
}
|
||||
var ev PunchEvent
|
||||
if err := json.Unmarshal([]byte(data.String()), &ev); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &ev, nil
|
||||
}
|
||||
if strings.HasPrefix(line, ":") {
|
||||
continue
|
||||
}
|
||||
field, value, ok := strings.Cut(line, ":")
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
value = strings.TrimPrefix(value, " ")
|
||||
switch field {
|
||||
case "event":
|
||||
eventName = value
|
||||
case "data":
|
||||
if data.Len() > 0 {
|
||||
data.WriteByte('\n')
|
||||
}
|
||||
data.WriteString(value)
|
||||
}
|
||||
}
|
||||
if err := s.scanner.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, io.EOF
|
||||
}
|
||||
@@ -0,0 +1,145 @@
|
||||
package realm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
)
|
||||
|
||||
const (
|
||||
MaxPunchPadding = 1024
|
||||
|
||||
punchSaltLen = 8
|
||||
// Plain punch payload before obfs:
|
||||
// 8-byte magic, 1-byte type, 16-byte nonce, then 0..1024 random padding bytes.
|
||||
punchHeaderLen = 25
|
||||
punchMinWireLen = punchSaltLen + punchHeaderLen
|
||||
punchMaxWireLen = punchMinWireLen + MaxPunchPadding
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidPunchPacket = errors.New("invalid punch packet")
|
||||
|
||||
punchMagic = [8]byte{'H', 'Y', 'R', 'L', 'M', 'v', '1', 0}
|
||||
)
|
||||
|
||||
type PunchPacketType byte
|
||||
|
||||
const (
|
||||
PunchPacketHello PunchPacketType = 0x01
|
||||
PunchPacketAck PunchPacketType = 0x02
|
||||
)
|
||||
|
||||
type PunchPacket struct {
|
||||
Type PunchPacketType
|
||||
PaddingLength int
|
||||
}
|
||||
|
||||
func EncodePunchPacket(packetType PunchPacketType, meta PunchMetadata) ([]byte, error) {
|
||||
if !validPunchPacketType(packetType) {
|
||||
return nil, fmt.Errorf("%w: unknown packet type", ErrInvalidPunchPacket)
|
||||
}
|
||||
nonce, obfsKey, err := decodePunchMetadata(meta)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
paddingLength, err := randomPaddingLength()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
plain := make([]byte, punchHeaderLen+paddingLength)
|
||||
copy(plain[:len(punchMagic)], punchMagic[:])
|
||||
plain[len(punchMagic)] = byte(packetType)
|
||||
copy(plain[len(punchMagic)+1:punchHeaderLen], nonce)
|
||||
if paddingLength > 0 {
|
||||
if _, err := rand.Read(plain[punchHeaderLen:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
packet := make([]byte, punchSaltLen+len(plain))
|
||||
if _, err := rand.Read(packet[:punchSaltLen]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
copy(packet[punchSaltLen:], plain)
|
||||
xorPunchPacket(packet[punchSaltLen:], obfsKey, packet[:punchSaltLen])
|
||||
return packet, nil
|
||||
}
|
||||
|
||||
func DecodePunchPacket(packet []byte, meta PunchMetadata) (PunchPacket, error) {
|
||||
if len(packet) < punchMinWireLen {
|
||||
return PunchPacket{}, fmt.Errorf("%w: packet too short", ErrInvalidPunchPacket)
|
||||
}
|
||||
if len(packet) > punchMaxWireLen {
|
||||
return PunchPacket{}, fmt.Errorf("%w: packet too long", ErrInvalidPunchPacket)
|
||||
}
|
||||
nonce, obfsKey, err := decodePunchMetadata(meta)
|
||||
if err != nil {
|
||||
return PunchPacket{}, err
|
||||
}
|
||||
salt := packet[:punchSaltLen]
|
||||
plain := append([]byte(nil), packet[punchSaltLen:]...)
|
||||
xorPunchPacket(plain, obfsKey, salt)
|
||||
if !bytes.Equal(plain[:len(punchMagic)], punchMagic[:]) {
|
||||
return PunchPacket{}, fmt.Errorf("%w: bad magic", ErrInvalidPunchPacket)
|
||||
}
|
||||
packetType := PunchPacketType(plain[len(punchMagic)])
|
||||
if !validPunchPacketType(packetType) {
|
||||
return PunchPacket{}, fmt.Errorf("%w: unknown packet type", ErrInvalidPunchPacket)
|
||||
}
|
||||
if !bytes.Equal(plain[len(punchMagic)+1:punchHeaderLen], nonce) {
|
||||
return PunchPacket{}, fmt.Errorf("%w: nonce mismatch", ErrInvalidPunchPacket)
|
||||
}
|
||||
return PunchPacket{
|
||||
Type: packetType,
|
||||
PaddingLength: len(plain) - punchHeaderLen,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func decodePunchMetadata(meta PunchMetadata) (nonce, obfsKey []byte, err error) {
|
||||
nonce, err = decodeHexSize("nonce", meta.Nonce, PunchNonceSize)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
obfsKey, err = decodeHexSize("obfs", meta.Obfs, PunchObfsKeySize)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return nonce, obfsKey, nil
|
||||
}
|
||||
|
||||
func decodeHexSize(name, value string, size int) ([]byte, error) {
|
||||
b, err := hex.DecodeString(value)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: invalid %s", ErrInvalidPunchPacket, name)
|
||||
}
|
||||
if len(b) != size {
|
||||
return nil, fmt.Errorf("%w: invalid %s length", ErrInvalidPunchPacket, name)
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func randomPaddingLength() (int, error) {
|
||||
n, err := rand.Int(rand.Reader, big.NewInt(MaxPunchPadding+1))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int(n.Int64()), nil
|
||||
}
|
||||
|
||||
func xorPunchPacket(packet, obfsKey, salt []byte) {
|
||||
h := sha256.New()
|
||||
_, _ = h.Write(obfsKey)
|
||||
_, _ = h.Write(salt)
|
||||
mask := h.Sum(nil)
|
||||
for i := range packet {
|
||||
packet[i] ^= mask[i%len(mask)]
|
||||
}
|
||||
}
|
||||
|
||||
func validPunchPacketType(packetType PunchPacketType) bool {
|
||||
return packetType == PunchPacketHello || packetType == PunchPacketAck
|
||||
}
|
||||
@@ -0,0 +1,401 @@
|
||||
package realm
|
||||
|
||||
import (
|
||||
"context"
|
||||
go_errors "errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/pion/stun/v3"
|
||||
"github.com/xtls/xray-core/common"
|
||||
"github.com/xtls/xray-core/common/errors"
|
||||
)
|
||||
|
||||
const defaultEventBuffer = 16
|
||||
const defaultStunCacheTTL = time.Second * 10
|
||||
const defaultHeartbeatInterval = time.Second * 15
|
||||
|
||||
type PunchPacketEvent struct {
|
||||
Addr netip.AddrPort
|
||||
Packet PunchPacket
|
||||
}
|
||||
|
||||
type STUNPacketEvent struct {
|
||||
Message *stun.Message
|
||||
Addr netip.AddrPort
|
||||
}
|
||||
|
||||
type realmConnServer struct {
|
||||
cleaned chan struct{}
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
net.PacketConn
|
||||
|
||||
realmClient *Client
|
||||
realmID string
|
||||
stunServers []string
|
||||
stunTimeout time.Duration
|
||||
punchTimeout time.Duration
|
||||
punchInterval time.Duration
|
||||
|
||||
events map[PunchMetadata]chan PunchPacketEvent
|
||||
stun chan STUNPacketEvent
|
||||
mu sync.Mutex
|
||||
|
||||
locals []netip.AddrPort
|
||||
localsMu sync.Mutex
|
||||
localsLast time.Time
|
||||
}
|
||||
|
||||
func NewConnServer(config *Config, raw net.PacketConn) (net.PacketConn, error) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
conn := &realmConnServer{
|
||||
cleaned: make(chan struct{}),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
PacketConn: raw,
|
||||
|
||||
realmClient: NewClient(config.Scheme, config.Host, config.Port, config.Token, config.TlsConfig),
|
||||
realmID: config.ID,
|
||||
stunServers: config.StunServers,
|
||||
stunTimeout: defaultSTUNTimeout,
|
||||
punchTimeout: defaultPunchTimeout,
|
||||
punchInterval: defaultPunchInterval,
|
||||
|
||||
events: make(map[PunchMetadata]chan PunchPacketEvent),
|
||||
stun: make(chan STUNPacketEvent, defaultEventBuffer),
|
||||
}
|
||||
|
||||
go conn.run()
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (c *realmConnServer) addSTUN(packet []byte) bool {
|
||||
if !stun.IsMessage(packet) {
|
||||
return false
|
||||
}
|
||||
msg, addr, err := parseSTUNBindingResponse(packet)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
select {
|
||||
case c.stun <- STUNPacketEvent{Message: msg, Addr: addr}:
|
||||
default:
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *realmConnServer) addPunch(packet []byte, addr net.Addr) bool {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
for meta, ch := range c.events {
|
||||
punchPacket, err := DecodePunchPacket(packet, meta)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
select {
|
||||
case ch <- PunchPacketEvent{
|
||||
Addr: addr.(*net.UDPAddr).AddrPort(),
|
||||
Packet: punchPacket,
|
||||
}:
|
||||
default:
|
||||
}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *realmConnServer) waitctx(ctx context.Context, t time.Duration) bool {
|
||||
timer := time.NewTimer(t)
|
||||
defer timer.Stop()
|
||||
select {
|
||||
case <-timer.C:
|
||||
return false
|
||||
case <-ctx.Done():
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func (c *realmConnServer) discover(servers []*net.UDPAddr) []netip.AddrPort {
|
||||
var transactionIDs = make(map[[stun.TransactionIDSize]byte]struct{}, len(servers))
|
||||
for _, server := range servers {
|
||||
msg := common.Must2(stun.Build(stun.TransactionID, stun.BindingRequest))
|
||||
transactionIDs[msg.TransactionID] = struct{}{}
|
||||
_, _ = c.PacketConn.WriteTo(msg.Raw, server)
|
||||
}
|
||||
|
||||
var deadline = time.NewTimer(c.stunTimeout)
|
||||
var results = make([]netip.AddrPort, 0, len(servers))
|
||||
for len(transactionIDs) > 0 {
|
||||
select {
|
||||
case <-deadline.C:
|
||||
goto end
|
||||
case ev := <-c.stun:
|
||||
if _, ok := transactionIDs[ev.Message.TransactionID]; ok {
|
||||
delete(transactionIDs, ev.Message.TransactionID)
|
||||
results = append(results, ev.Addr)
|
||||
}
|
||||
}
|
||||
}
|
||||
end:
|
||||
deadline.Stop()
|
||||
slices.SortFunc(results, func(a, b netip.AddrPort) int {
|
||||
return strings.Compare(a.String(), b.String())
|
||||
})
|
||||
|
||||
return results
|
||||
}
|
||||
|
||||
func (c *realmConnServer) getlocals(force bool) []netip.AddrPort {
|
||||
c.localsMu.Lock()
|
||||
if force || time.Since(c.localsLast) > defaultStunCacheTTL {
|
||||
start := time.Now()
|
||||
servers := resolveSTUNServers(c.PacketConn.LocalAddr().(*net.UDPAddr).IP, c.stunServers)
|
||||
errors.LogDebug(context.Background(), "[realm] update stun servers ", servers, " with ", time.Since(start))
|
||||
if len(servers) > 0 {
|
||||
start = time.Now()
|
||||
locals := c.discover(servers)
|
||||
errors.LogDebug(context.Background(), "[realm] update stun locals ", locals, " with ", time.Since(start))
|
||||
if len(locals) > 0 {
|
||||
c.locals = locals
|
||||
c.localsLast = time.Now()
|
||||
}
|
||||
}
|
||||
}
|
||||
locals := append([]netip.AddrPort(nil), c.locals...)
|
||||
c.localsMu.Unlock()
|
||||
return locals
|
||||
}
|
||||
|
||||
func (c *realmConnServer) punch(ctx context.Context, meta PunchMetadata, peers []netip.AddrPort) {
|
||||
c.mu.Lock()
|
||||
if _, ok := c.events[meta]; ok {
|
||||
c.mu.Unlock()
|
||||
return
|
||||
}
|
||||
ch := make(chan PunchPacketEvent, defaultEventBuffer)
|
||||
c.events[meta] = ch
|
||||
c.mu.Unlock()
|
||||
|
||||
start := time.Now()
|
||||
for _, peer := range peers {
|
||||
packet := common.Must2(EncodePunchPacket(PunchPacketHello, meta))
|
||||
_, _ = c.PacketConn.WriteTo(packet, net.UDPAddrFromAddrPort(peer))
|
||||
}
|
||||
deadline := time.NewTimer(c.punchTimeout)
|
||||
ticker := time.NewTicker(c.punchInterval)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
errors.LogDebug(context.Background(), "[realm] punch ", meta.Nonce, " FAIL > session end")
|
||||
goto end
|
||||
case <-deadline.C:
|
||||
errors.LogDebug(context.Background(), "[realm] punch ", meta.Nonce, " FAIL > timeout")
|
||||
goto end
|
||||
case <-ticker.C:
|
||||
for _, peer := range peers {
|
||||
packet := common.Must2(EncodePunchPacket(PunchPacketHello, meta))
|
||||
_, _ = c.PacketConn.WriteTo(packet, net.UDPAddrFromAddrPort(peer))
|
||||
}
|
||||
case event := <-ch:
|
||||
if event.Packet.Type == PunchPacketHello {
|
||||
packet := common.Must2(EncodePunchPacket(PunchPacketAck, meta))
|
||||
_, _ = c.PacketConn.WriteTo(packet, net.UDPAddrFromAddrPort(event.Addr))
|
||||
}
|
||||
errors.LogDebug(context.Background(), "[realm] punch ", meta.Nonce, " SUCCESS ", event.Addr, " with ", time.Since(start))
|
||||
goto end
|
||||
}
|
||||
}
|
||||
end:
|
||||
deadline.Stop()
|
||||
ticker.Stop()
|
||||
|
||||
c.mu.Lock()
|
||||
delete(c.events, meta)
|
||||
close(ch)
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
func (c *realmConnServer) run() {
|
||||
backoff := time.Second
|
||||
retry:
|
||||
resp, err := c.realmClient.Register(c.ctx, c.realmID, addrPortStrings(c.getlocals(false)))
|
||||
if err != nil {
|
||||
errors.LogErrorInner(context.Background(), err, "[realm] ", c.realmID, " register session err retry in ", backoff)
|
||||
if c.waitctx(c.ctx, backoff) {
|
||||
close(c.cleaned)
|
||||
return
|
||||
}
|
||||
backoff *= 2
|
||||
if backoff > 30*time.Second {
|
||||
backoff = 30 * time.Second
|
||||
}
|
||||
goto retry
|
||||
}
|
||||
backoff = time.Second
|
||||
errors.LogDebug(context.Background(), "[realm] ", c.realmID, " sesssion ", resp.SessionID, " ", resp.TTL, " registered")
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
errCh := make(chan error, 2)
|
||||
go c.heartbeatLoop(ctx, resp.SessionID, resp.TTL, errCh)
|
||||
go c.eventsLoop(ctx, resp.SessionID, resp.TTL, errCh)
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
case err = <-errCh:
|
||||
}
|
||||
cancel()
|
||||
errors.LogDebugInner(context.Background(), err, "[realm] session ", resp.SessionID, " end")
|
||||
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
_ = c.realmClient.Deregister(context.Background(), c.realmID, resp.SessionID)
|
||||
errors.LogDebug(context.Background(), "[realm] ", c.realmID, " ", resp.SessionID, " deregistered")
|
||||
close(c.cleaned)
|
||||
return
|
||||
default:
|
||||
goto retry
|
||||
}
|
||||
}
|
||||
|
||||
func (c *realmConnServer) heartbeatLoop(ctx context.Context, sid string, ttl int, errCh chan<- error) {
|
||||
interval := defaultHeartbeatInterval
|
||||
if ttl > 0 {
|
||||
interval = time.Second * time.Duration(ttl) / 2
|
||||
}
|
||||
|
||||
last := time.Now()
|
||||
cur := c.getlocals(false)
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
errCh <- nil
|
||||
return
|
||||
case <-ticker.C:
|
||||
req := HeartbeatRequest{}
|
||||
if new := c.getlocals(false); !slices.Equal(cur, new) {
|
||||
cur = new
|
||||
req.Addresses = addrPortStrings(cur)
|
||||
}
|
||||
start := time.Now()
|
||||
resp, err := c.realmClient.Heartbeat(ctx, c.realmID, sid, req)
|
||||
if err != nil {
|
||||
var statusErr *StatusError
|
||||
if go_errors.As(err, &statusErr) && (statusErr.StatusCode == http.StatusUnauthorized || statusErr.StatusCode == http.StatusNotFound) {
|
||||
errCh <- errors.New("session invalid")
|
||||
return
|
||||
}
|
||||
if time.Since(last) > time.Second*time.Duration(ttl) {
|
||||
errCh <- errors.New("session lost")
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
last = start
|
||||
errors.LogDebug(context.Background(), "[realm] heartbeat ", resp.TTL, " with ", time.Since(start))
|
||||
if resp.TTL > 0 && resp.TTL != ttl {
|
||||
ttl = resp.TTL
|
||||
ticker.Reset(time.Second * time.Duration(ttl) / 2)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *realmConnServer) eventsLoop(ctx context.Context, sid string, ttl int, errCh chan<- error) {
|
||||
backoff := time.Second
|
||||
last := time.Now()
|
||||
for {
|
||||
start := time.Now()
|
||||
stream, err := c.realmClient.Events(ctx, c.realmID, sid)
|
||||
if err != nil {
|
||||
var statusErr *StatusError
|
||||
if go_errors.As(err, &statusErr) && (statusErr.StatusCode == http.StatusUnauthorized || statusErr.StatusCode == http.StatusNotFound) {
|
||||
errCh <- errors.New("session invalid")
|
||||
return
|
||||
}
|
||||
if time.Since(last) > time.Second*time.Duration(ttl) {
|
||||
errCh <- errors.New("session lost")
|
||||
return
|
||||
}
|
||||
errors.LogDebugInner(context.Background(), err, "[realm] ", sid, " open stream err retry in ", backoff)
|
||||
if c.waitctx(ctx, backoff) {
|
||||
errCh <- nil
|
||||
return
|
||||
}
|
||||
backoff *= 2
|
||||
if backoff > 30*time.Second {
|
||||
backoff = 30 * time.Second
|
||||
}
|
||||
continue
|
||||
}
|
||||
backoff = time.Second
|
||||
last = start
|
||||
errors.LogDebug(context.Background(), "[realm] open stream with ", time.Since(start))
|
||||
for {
|
||||
ev, err := stream.Next()
|
||||
if err != nil {
|
||||
_ = stream.Close()
|
||||
break
|
||||
}
|
||||
last = time.Now()
|
||||
go c.punchEvent(ctx, sid, ev)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *realmConnServer) punchEvent(ctx context.Context, sid string, ev *PunchEvent) {
|
||||
errors.LogDebug(context.Background(), "[realm] start punch event ", ev.Nonce, " ", ev.Addresses)
|
||||
|
||||
locals := c.getlocals(false)
|
||||
|
||||
peers, _ := parseAddrPorts(ev.Addresses)
|
||||
errors.LogDebug(context.Background(), "[realm] ", ev.Nonce, " update peers ", peers)
|
||||
filteredPeers, seen := candidatePunchAddrs(locals, peers)
|
||||
errors.LogDebug(context.Background(), "[realm] ", ev.Nonce, " filtered peers ", filteredPeers)
|
||||
expandedPeers := expandSymmetricNATCandidates(filteredPeers, seen)
|
||||
errors.LogDebug(context.Background(), "[realm] ", ev.Nonce, " expanded peers ", expandedPeers)
|
||||
|
||||
if len(expandedPeers) == 0 {
|
||||
errors.LogDebug(context.Background(), "[realm] punch ", ev.Nonce, " FAIL > empty peers")
|
||||
return
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
err := c.realmClient.ConnectResponse(ctx, c.realmID, sid, ev.Nonce, addrPortStrings(locals))
|
||||
if err != nil {
|
||||
errors.LogDebugInner(context.Background(), err, "[realm] ", ev.Nonce, " connect response err")
|
||||
}
|
||||
errors.LogDebug(context.Background(), "[realm] ", ev.Nonce, " connect response ", locals, " with ", time.Since(start))
|
||||
|
||||
c.punch(ctx, ev.PunchMetadata, expandedPeers)
|
||||
}
|
||||
|
||||
func (c *realmConnServer) ReadFrom(p []byte) (int, net.Addr, error) {
|
||||
for {
|
||||
n, addr, err := c.PacketConn.ReadFrom(p)
|
||||
if err != nil {
|
||||
return n, addr, err
|
||||
}
|
||||
if c.addSTUN(p[:n]) {
|
||||
continue
|
||||
}
|
||||
if c.addPunch(p[:n], addr) {
|
||||
continue
|
||||
}
|
||||
return n, addr, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *realmConnServer) Close() error {
|
||||
c.cancel()
|
||||
<-c.cleaned
|
||||
return c.PacketConn.Close()
|
||||
}
|
||||
@@ -0,0 +1,224 @@
|
||||
package realm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pion/stun/v3"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultSTUNTimeout = 4 * time.Second
|
||||
defaultPunchTimeout = 10 * time.Second
|
||||
defaultPunchInterval = 100 * time.Millisecond
|
||||
|
||||
symmetricNATPortGap = 4
|
||||
symmetricNATExtraPorts = 4
|
||||
symmetricNATMaxPortsPerHost = 32
|
||||
)
|
||||
|
||||
func resolveSTUNServers(local net.IP, servers []string) []*net.UDPAddr {
|
||||
var network string
|
||||
if local.IsUnspecified() {
|
||||
network = "ip"
|
||||
} else {
|
||||
if local.To4() != nil {
|
||||
network = "ip4"
|
||||
} else {
|
||||
network = "ip6"
|
||||
}
|
||||
}
|
||||
|
||||
var seen = make(map[string]struct{})
|
||||
var addrs = make([]*net.UDPAddr, 0, len(servers))
|
||||
for _, server := range servers {
|
||||
h, p, err := net.SplitHostPort(server)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
port, err := strconv.Atoi(p)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
ips, err := net.DefaultResolver.LookupIP(context.Background(), network, h)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
for _, ip := range ips {
|
||||
if _, ok := seen[net.JoinHostPort(ip.String(), p)]; !ok {
|
||||
seen[net.JoinHostPort(ip.String(), p)] = struct{}{}
|
||||
addrs = append(addrs, &net.UDPAddr{IP: ip, Port: port})
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return addrs
|
||||
}
|
||||
|
||||
func parseSTUNBindingResponse(packet []byte) (*stun.Message, netip.AddrPort, error) {
|
||||
msg := stun.New()
|
||||
if err := stun.Decode(packet, msg); err != nil {
|
||||
return nil, netip.AddrPort{}, err
|
||||
}
|
||||
if msg.Type != stun.BindingSuccess {
|
||||
return nil, netip.AddrPort{}, errors.New("not a STUN binding success response")
|
||||
}
|
||||
|
||||
var xorMapped stun.XORMappedAddress
|
||||
if err := xorMapped.GetFrom(msg); err == nil {
|
||||
addr, err := netIPPortToAddrPort(xorMapped.IP, xorMapped.Port)
|
||||
return msg, addr, err
|
||||
}
|
||||
|
||||
var mapped stun.MappedAddress
|
||||
if err := mapped.GetFrom(msg); err == nil {
|
||||
addr, err := netIPPortToAddrPort(mapped.IP, mapped.Port)
|
||||
return msg, addr, err
|
||||
}
|
||||
|
||||
return nil, netip.AddrPort{}, errors.New("STUN mapped address not found")
|
||||
}
|
||||
|
||||
func netIPPortToAddrPort(ip net.IP, port int) (netip.AddrPort, error) {
|
||||
if port <= 0 || port > 65535 {
|
||||
return netip.AddrPort{}, errors.New("invalid STUN mapped port")
|
||||
}
|
||||
if ip4 := ip.To4(); ip4 != nil {
|
||||
var addr [4]byte
|
||||
copy(addr[:], ip4)
|
||||
return netip.AddrPortFrom(netip.AddrFrom4(addr), uint16(port)), nil
|
||||
}
|
||||
ip16 := ip.To16()
|
||||
if ip16 == nil {
|
||||
return netip.AddrPort{}, errors.New("invalid STUN mapped IP")
|
||||
}
|
||||
var addr [16]byte
|
||||
copy(addr[:], ip16)
|
||||
return netip.AddrPortFrom(netip.AddrFrom16(addr), uint16(port)), nil
|
||||
}
|
||||
|
||||
func candidatePunchAddrs(locals, peers []netip.AddrPort) ([]netip.AddrPort, map[netip.AddrPort]struct{}) {
|
||||
var allow4, allow6 bool
|
||||
for _, local := range locals {
|
||||
if local.Addr().Is4() {
|
||||
allow4 = true
|
||||
} else {
|
||||
allow6 = true
|
||||
}
|
||||
if allow4 && allow6 {
|
||||
break
|
||||
}
|
||||
}
|
||||
var seen = make(map[netip.AddrPort]struct{}, len(peers))
|
||||
var candidates = make([]netip.AddrPort, 0, len(peers))
|
||||
for _, peer := range peers {
|
||||
if _, ok := seen[peer]; ok {
|
||||
continue
|
||||
}
|
||||
if peer.IsValid() {
|
||||
if peer.Addr().Is4() {
|
||||
if allow4 {
|
||||
seen[peer] = struct{}{}
|
||||
candidates = append(candidates, peer)
|
||||
}
|
||||
} else {
|
||||
if allow6 {
|
||||
seen[peer] = struct{}{}
|
||||
candidates = append(candidates, peer)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return candidates, seen
|
||||
}
|
||||
|
||||
func expandSymmetricNATCandidates(candidates []netip.AddrPort, seen map[netip.AddrPort]struct{}) []netip.AddrPort {
|
||||
portsByIP := make(map[netip.Addr][]uint16)
|
||||
for _, addr := range candidates {
|
||||
if addr.Addr().Is4() {
|
||||
portsByIP[addr.Addr()] = append(portsByIP[addr.Addr()], addr.Port())
|
||||
}
|
||||
}
|
||||
for ip, ports := range portsByIP {
|
||||
ports = uniqueSortedPorts(ports)
|
||||
if !predictablePortGroup(ports) {
|
||||
continue
|
||||
}
|
||||
start := int(ports[0])
|
||||
end := int(ports[len(ports)-1]) + symmetricNATExtraPorts
|
||||
if end > 65535 {
|
||||
end = 65535
|
||||
}
|
||||
added := 0
|
||||
for port := start; port <= end && added < symmetricNATMaxPortsPerHost; port++ {
|
||||
addr := netip.AddrPortFrom(ip, uint16(port))
|
||||
if _, ok := seen[addr]; ok {
|
||||
continue
|
||||
}
|
||||
seen[addr] = struct{}{}
|
||||
candidates = append(candidates, addr)
|
||||
added++
|
||||
}
|
||||
}
|
||||
sortAddrPorts(candidates)
|
||||
return candidates
|
||||
}
|
||||
|
||||
func uniqueSortedPorts(ports []uint16) []uint16 {
|
||||
slices.Sort(ports)
|
||||
out := ports[:0]
|
||||
var last uint16
|
||||
for i, port := range ports {
|
||||
if i > 0 && port == last {
|
||||
continue
|
||||
}
|
||||
out = append(out, port)
|
||||
last = port
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func predictablePortGroup(ports []uint16) bool {
|
||||
if len(ports) < 2 {
|
||||
return false
|
||||
}
|
||||
for i := 1; i < len(ports); i++ {
|
||||
if ports[i]-ports[i-1] > symmetricNATPortGap {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func sortAddrPorts(addrs []netip.AddrPort) {
|
||||
slices.SortFunc(addrs, func(a, b netip.AddrPort) int {
|
||||
return strings.Compare(a.String(), b.String())
|
||||
})
|
||||
}
|
||||
|
||||
func addrPortStrings(addrs []netip.AddrPort) []string {
|
||||
out := make([]string, 0, len(addrs))
|
||||
for _, addr := range addrs {
|
||||
out = append(out, addr.String())
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func parseAddrPorts(addrs []string) ([]netip.AddrPort, error) {
|
||||
out := make([]netip.AddrPort, 0, len(addrs))
|
||||
for _, s := range addrs {
|
||||
addr, err := netip.ParseAddrPort(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, addr)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
@@ -801,7 +801,8 @@ func (b *bandwidthSampler) onPacketAcknowledged(ackTime monotime.Time, packetNum
|
||||
if sentPacketPointer.sentTime.After(sentPacketPointer.lastAckedPacketSentTime) {
|
||||
sendRate = BandwidthFromDelta(
|
||||
sentPacketPointer.sendTimeState.totalBytesSent-sentPacketPointer.totalBytesSentAtLastAckedPacket,
|
||||
sentPacketPointer.sentTime.Sub(sentPacketPointer.lastAckedPacketSentTime))
|
||||
sentPacketPointer.sentTime.Sub(sentPacketPointer.lastAckedPacketSentTime),
|
||||
)
|
||||
}
|
||||
|
||||
var a0 ackPoint
|
||||
@@ -848,7 +849,8 @@ func (b *bandwidthSampler) onAckEventEnd(
|
||||
b.lastSentPacket,
|
||||
b.lastAckedPacket,
|
||||
b.lastAckedPacketAckTime,
|
||||
newlyAckedBytes)
|
||||
newlyAckedBytes,
|
||||
)
|
||||
// If |extra_acked| is zero, i.e. this ack event marks the start of a new ack
|
||||
// aggregation epoch, save LessRecentPoint, which is the last ack point of the
|
||||
// previous epoch, as a A0 candidate.
|
||||
|
||||
@@ -640,7 +640,8 @@ func (b *bbrSender) OnCongestionEventEx(priorInFlight congestion.ByteCount, even
|
||||
func (b *bbrSender) PacingRate() Bandwidth {
|
||||
if b.pacingRate == 0 {
|
||||
return Bandwidth(b.highGain * float64(
|
||||
BandwidthFromDelta(b.initialCongestionWindow, b.getMinRtt())))
|
||||
BandwidthFromDelta(b.initialCongestionWindow, b.getMinRtt()),
|
||||
))
|
||||
}
|
||||
|
||||
return b.pacingRate
|
||||
|
||||
Reference in New Issue
Block a user