From 66a81007379656b803d55ac2c093f0ad2408f2b2 Mon Sep 17 00:00:00 2001 From: LjhAUMEM Date: Fri, 29 May 2026 16:36:45 +0800 Subject: [PATCH] Salamander finalmask: Support `packetSize` (Gecko in Hysteria v2.9.2) (#6198) And some other refinements https://github.com/XTLS/Xray-core/pull/6198#issuecomment-4567438813 Example: https://github.com/XTLS/Xray-core/pull/6198#issue-4522226670 --- infra/conf/transport_internet.go | 16 +- proxy/hysteria/protocol.go | 3 + transport/internet/finalmask/finalmask.go | 109 +++--- .../mkcp/aes128gcm/aes128gcm_test.go | 70 ---- .../finalmask/mkcp/aes128gcm/config.go | 8 +- .../internet/finalmask/mkcp/aes128gcm/conn.go | 37 +- .../finalmask/mkcp/original/config.go | 8 +- .../internet/finalmask/mkcp/original/conn.go | 12 +- .../finalmask/mkcp/original/simple_test.go | 35 -- .../internet/finalmask/salamander/config.go | 18 +- .../finalmask/salamander/config.pb.go | 73 +++- .../finalmask/salamander/config.proto | 5 + .../internet/finalmask/salamander/conn.go | 321 +++++++++++++++++- .../internet/finalmask/salamander/gecko.go | 86 +++++ .../finalmask/salamander/salamander.go | 22 +- .../finalmask/salamander/salamander_test.go | 81 ----- transport/internet/finalmask/udp_test.go | 44 --- 17 files changed, 587 insertions(+), 361 deletions(-) delete mode 100644 transport/internet/finalmask/mkcp/aes128gcm/aes128gcm_test.go delete mode 100644 transport/internet/finalmask/mkcp/original/simple_test.go create mode 100644 transport/internet/finalmask/salamander/gecko.go delete mode 100644 transport/internet/finalmask/salamander/salamander_test.go diff --git a/infra/conf/transport_internet.go b/infra/conf/transport_internet.go index 02f94de0a..69a909226 100644 --- a/infra/conf/transport_internet.go +++ b/infra/conf/transport_internet.go @@ -1811,13 +1811,21 @@ func (c *Aes128Gcm) Build() (proto.Message, error) { } type Salamander struct { - Password string `json:"password"` + Password string `json:"password"` + PacketSize *Int32Range `json:"packetSize"` } func (c *Salamander) Build() (proto.Message, error) { - config := &salamander.Config{} - config.Password = c.Password - return config, nil + if c.PacketSize != nil { + return &salamander.GeckoConfig{ + Password: c.Password, + MinPacketSize: c.PacketSize.From, + MaxPacketSize: c.PacketSize.To, + }, nil + } + return &salamander.Config{ + Password: c.Password, + }, nil } type Sudoku struct { diff --git a/proxy/hysteria/protocol.go b/proxy/hysteria/protocol.go index 5434bd00c..7209172e1 100644 --- a/proxy/hysteria/protocol.go +++ b/proxy/hysteria/protocol.go @@ -253,6 +253,9 @@ func FragUDPMessage(m *UDPMessage, maxSize int) []UDPMessage { } fullPayload := m.Data maxPayloadSize := maxSize - m.HeaderSize() + if maxPayloadSize <= 0 { + return nil + } off := 0 fragID := uint8(0) fragCount := uint8((len(fullPayload) + maxPayloadSize - 1) / maxPayloadSize) // round up diff --git a/transport/internet/finalmask/finalmask.go b/transport/internet/finalmask/finalmask.go index c732f477c..d7697cfe1 100644 --- a/transport/internet/finalmask/finalmask.go +++ b/transport/internet/finalmask/finalmask.go @@ -3,9 +3,8 @@ package finalmask import ( "context" "net" - "sync" - "github.com/xtls/xray-core/common/bytespool" + "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/errors" ) @@ -105,61 +104,60 @@ type headerSize interface { } type headerManagerConn struct { - sync.Mutex net.PacketConn - sizes []int - conns []net.PacketConn - writeBuf [UDPSize]byte + sizes []int + conns []net.PacketConn } func (c *headerManagerConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { - buf := p - if len(buf) < UDPSize { - b := bytespool.Alloc(UDPSize) - b = b[:UDPSize] - defer bytespool.Free(b) - buf = b + b := p + if len(b) < UDPSize { + buf := buf.New() + buf.Resize(0, UDPSize) + b = buf.Bytes() + defer buf.Release() } - n, addr, err = c.PacketConn.ReadFrom(buf) - if n == 0 || err != nil { - return 0, addr, err - } - newBuf := buf[:n] - - sum := 0 - for _, size := range c.sizes { - sum += size - } - - if n < sum { - errors.LogDebug(context.Background(), addr, " mask read err short length") - return 0, addr, nil - } - - for i := range c.conns { - n, _, err = c.conns[i].ReadFrom(newBuf) - if n == 0 || err != nil { - errors.LogDebug(context.Background(), addr, " mask read err ", err) - return 0, addr, nil + for { + n, addr, err = c.PacketConn.ReadFrom(b) + if err != nil { + return n, addr, err } - newBuf = newBuf[c.sizes[i] : n+c.sizes[i]] + b = b[:n] + + sum := 0 + for _, size := range c.sizes { + sum += size + } + + if n < sum { + errors.LogError(context.Background(), "[mask] drop packet from ", addr, " with size ", len(b)) + continue + } + + for i := range c.conns { + n, _, err = c.conns[i].ReadFrom(b) + if err != nil { + errors.LogErrorInner(context.Background(), err, "[mask] drop packet from ", addr, " with size ", len(b)) + break + } + b = b[c.sizes[i] : n+c.sizes[i]] + } + + if err != nil { + continue + } + + return copy(p, b), addr, nil } - - if len(p) < n { - errors.LogDebug(context.Background(), addr, " mask read err short buffer") - return 0, addr, nil - } - - copy(p, buf[sum:sum+n]) - - return n, addr, nil } func (c *headerManagerConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { - c.Lock() - defer c.Unlock() + buf := buf.New() + buf.Resize(0, UDPSize) + b := buf.Bytes() + defer buf.Release() sum := 0 for _, size := range c.sizes { @@ -167,24 +165,29 @@ func (c *headerManagerConn) WriteTo(p []byte, addr net.Addr) (n int, err error) } if sum+len(p) > UDPSize { - errors.LogDebug(context.Background(), addr, " mask write err short write") + errors.LogError(context.Background(), "[mask] drop packet to ", addr, " with size ", len(p)) return 0, nil } - n = copy(c.writeBuf[sum:], p) + n = copy(b[sum:], p) for i := len(c.conns) - 1; i >= 0; i-- { - n, err = c.conns[i].WriteTo(c.writeBuf[sum-c.sizes[i]:n+sum], nil) - if n == 0 || err != nil { - errors.LogDebug(context.Background(), addr, " mask write err ", err) + n, err = c.conns[i].WriteTo(b[sum-c.sizes[i]:n+sum], nil) + if err != nil { + errors.LogErrorInner(context.Background(), err, "[mask] drop packet to ", addr, " with size ", len(p)) return 0, nil } sum -= c.sizes[i] } - n, err = c.PacketConn.WriteTo(c.writeBuf[:n], addr) - if n == 0 || err != nil { - return n, err + if n > UDPSize { + errors.LogError(context.Background(), "[mask] drop packet to ", addr, " with size ", len(p)) + return 0, nil + } + + _, err = c.PacketConn.WriteTo(b[:n], addr) + if err != nil { + return 0, err } return len(p), nil diff --git a/transport/internet/finalmask/mkcp/aes128gcm/aes128gcm_test.go b/transport/internet/finalmask/mkcp/aes128gcm/aes128gcm_test.go deleted file mode 100644 index 4806dfc2c..000000000 --- a/transport/internet/finalmask/mkcp/aes128gcm/aes128gcm_test.go +++ /dev/null @@ -1,70 +0,0 @@ -package aes128gcm_test - -import ( - "crypto/rand" - "crypto/sha256" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/xtls/xray-core/common/crypto" -) - -func TestAes128GcmSealInPlace(t *testing.T) { - hashedPsk := sha256.Sum256([]byte("psk")) - aead := crypto.NewAesGcm(hashedPsk[:16]) - - text := []byte("0123456789012") - buf := make([]byte, 8192) - - nonceSize := aead.NonceSize() - nonce := buf[:nonceSize] - rand.Read(nonce) - copy(buf[nonceSize:], text) - plaintext := buf[nonceSize : nonceSize+len(text)] - - sealed := aead.Seal(nil, nonce, plaintext, nil) - - _ = aead.Seal(plaintext[:0], nonce, plaintext, nil) - - assert.Equal(t, sealed, buf[nonceSize:nonceSize+aead.Overhead()+len(text)]) -} - -func encrypted(plain []byte) ([]byte, []byte) { - hashedPsk := sha256.Sum256([]byte("psk")) - aead := crypto.NewAesGcm(hashedPsk[:16]) - - nonce := make([]byte, 12) - rand.Read(nonce) - - return nonce, aead.Seal(nil, nonce, plain, nil) -} - -func TestAes128GcmOpenInPlace(t *testing.T) { - a, b := encrypted([]byte("0123456789012")) - buf := make([]byte, 8192) - copy(buf, a) - copy(buf[len(a):], b) - - hashedPsk := sha256.Sum256([]byte("psk")) - aead := crypto.NewAesGcm(hashedPsk[:16]) - - nonceSize := aead.NonceSize() - nonce := buf[:nonceSize] - ciphertext := buf[nonceSize : nonceSize+len(b)] - - opened, _ := aead.Open(nil, nonce, ciphertext, nil) - _, _ = aead.Open(ciphertext[:0], nonce, ciphertext, nil) - - assert.Equal(t, opened, ciphertext[:len(ciphertext)-aead.Overhead()]) -} - -func TestAes128GcmBounce(t *testing.T) { - hashedPsk := sha256.Sum256([]byte("psk")) - aead := crypto.NewAesGcm(hashedPsk[:16]) - buf := make([]byte, aead.NonceSize()+aead.Overhead()) - for i := 0; i < 1000; i++ { - _, _ = rand.Read(buf) - _, err := aead.Open(buf[aead.NonceSize():aead.NonceSize()], buf[:aead.NonceSize()], buf[aead.NonceSize():], nil) - assert.NotEqual(t, err, nil) - } -} diff --git a/transport/internet/finalmask/mkcp/aes128gcm/config.go b/transport/internet/finalmask/mkcp/aes128gcm/config.go index a7160e2b0..70d87ae00 100644 --- a/transport/internet/finalmask/mkcp/aes128gcm/config.go +++ b/transport/internet/finalmask/mkcp/aes128gcm/config.go @@ -4,8 +4,9 @@ import ( "net" ) -func (c *Config) UDP() { -} +func (c *Config) UDP() {} + +func (c *Config) HeaderConn() {} func (c *Config) WrapPacketConnClient(raw net.PacketConn, level int, levelCount int) (net.PacketConn, error) { return NewConnClient(c, raw) @@ -14,6 +15,3 @@ func (c *Config) WrapPacketConnClient(raw net.PacketConn, level int, levelCount func (c *Config) WrapPacketConnServer(raw net.PacketConn, level int, levelCount int) (net.PacketConn, error) { return NewConnServer(c, raw) } - -func (c *Config) HeaderConn() { -} diff --git a/transport/internet/finalmask/mkcp/aes128gcm/conn.go b/transport/internet/finalmask/mkcp/aes128gcm/conn.go index 055803af2..6d1012a38 100644 --- a/transport/internet/finalmask/mkcp/aes128gcm/conn.go +++ b/transport/internet/finalmask/mkcp/aes128gcm/conn.go @@ -8,8 +8,6 @@ import ( "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/crypto" - "github.com/xtls/xray-core/common/errors" - "github.com/xtls/xray-core/transport/internet/finalmask" ) type aes128gcmConn struct { @@ -19,13 +17,10 @@ type aes128gcmConn struct { func NewConnClient(c *Config, raw net.PacketConn) (net.PacketConn, error) { hashedPsk := sha256.Sum256([]byte(c.Password)) - - conn := &aes128gcmConn{ + return &aes128gcmConn{ PacketConn: raw, aead: crypto.NewAesGcm(hashedPsk[:16]), - } - - return conn, nil + }, nil } func NewConnServer(c *Config, raw net.PacketConn) (net.PacketConn, error) { @@ -37,31 +32,19 @@ func (c *aes128gcmConn) Size() int { } func (c *aes128gcmConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { - if len(p) < c.aead.NonceSize()+c.aead.Overhead() { - return 0, addr, errors.New("aead short lenth") - } - nonceSize := c.aead.NonceSize() - nonce := p[:nonceSize] - ciphertext := p[nonceSize:] - _, err = c.aead.Open(ciphertext[:0], nonce, ciphertext, nil) + overhead := c.aead.Overhead() + _, err = c.aead.Open(p[nonceSize:nonceSize], p[:nonceSize], p[nonceSize:], nil) if err != nil { - return 0, addr, errors.New("aead open").Base(err) + return 0, nil, err } - - return len(p) - c.aead.NonceSize() - c.aead.Overhead(), addr, nil + return len(p) - nonceSize - overhead, nil, nil } func (c *aes128gcmConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { - if c.aead.Overhead()+len(p) > finalmask.UDPSize { - return 0, errors.New("aead short write") - } - nonceSize := c.aead.NonceSize() - nonce := p[:nonceSize] - common.Must2(rand.Read(nonce)) - plaintext := p[nonceSize:] - _ = c.aead.Seal(plaintext[:0], nonce, plaintext, nil) - - return len(p) + c.aead.Overhead(), nil + overhead := c.aead.Overhead() + common.Must2(rand.Read(p[:nonceSize])) + _ = c.aead.Seal(p[nonceSize:nonceSize], p[:nonceSize], p[nonceSize:], nil) + return len(p) + overhead, nil } diff --git a/transport/internet/finalmask/mkcp/original/config.go b/transport/internet/finalmask/mkcp/original/config.go index d18b13918..b6b42c73d 100644 --- a/transport/internet/finalmask/mkcp/original/config.go +++ b/transport/internet/finalmask/mkcp/original/config.go @@ -4,8 +4,9 @@ import ( "net" ) -func (c *Config) UDP() { -} +func (c *Config) UDP() {} + +func (c *Config) HeaderConn() {} func (c *Config) WrapPacketConnClient(raw net.PacketConn, level int, levelCount int) (net.PacketConn, error) { return NewConnClient(c, raw) @@ -14,6 +15,3 @@ func (c *Config) WrapPacketConnClient(raw net.PacketConn, level int, levelCount func (c *Config) WrapPacketConnServer(raw net.PacketConn, level int, levelCount int) (net.PacketConn, error) { return NewConnServer(c, raw) } - -func (c *Config) HeaderConn() { -} diff --git a/transport/internet/finalmask/mkcp/original/conn.go b/transport/internet/finalmask/mkcp/original/conn.go index 15abe99be..5fbe598df 100644 --- a/transport/internet/finalmask/mkcp/original/conn.go +++ b/transport/internet/finalmask/mkcp/original/conn.go @@ -77,12 +77,10 @@ type simpleConn struct { } func NewConnClient(c *Config, raw net.PacketConn) (net.PacketConn, error) { - conn := &simpleConn{ + return &simpleConn{ PacketConn: raw, aead: &simple{}, - } - - return conn, nil + }, nil } func NewConnServer(c *Config, raw net.PacketConn) (net.PacketConn, error) { @@ -96,14 +94,12 @@ func (c *simpleConn) Size() int { func (c *simpleConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { _, err = c.aead.Open(p[:0], nil, p, nil) if err != nil { - return 0, addr, errors.New("aead open").Base(err) + return 0, nil, err } - - return len(p) - c.aead.Overhead(), addr, nil + return len(p) - c.aead.Overhead(), nil, nil } func (c *simpleConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { _ = c.aead.Seal(p[:0], nil, p[c.aead.Overhead():], nil) - return len(p), nil } diff --git a/transport/internet/finalmask/mkcp/original/simple_test.go b/transport/internet/finalmask/mkcp/original/simple_test.go deleted file mode 100644 index be9d68398..000000000 --- a/transport/internet/finalmask/mkcp/original/simple_test.go +++ /dev/null @@ -1,35 +0,0 @@ -package original_test - -import ( - "crypto/rand" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/xtls/xray-core/transport/internet/finalmask/mkcp/original" -) - -func TestSimpleSealInPlace(t *testing.T) { - aead := original.NewSimple() - - text := []byte("0123456789012") - buf := make([]byte, 8192) - - copy(buf[aead.Overhead():], text) - plaintext := buf[aead.Overhead() : aead.Overhead()+len(text)] - - sealed := aead.Seal(nil, nil, plaintext, nil) - - _ = aead.Seal(buf[:0], nil, plaintext, nil) - - assert.Equal(t, sealed, buf[:aead.Overhead()+len(text)]) -} - -func TestOriginalBounce(t *testing.T) { - aead := original.NewSimple() - buf := make([]byte, aead.NonceSize()+aead.Overhead()) - for i := 0; i < 1000; i++ { - _, _ = rand.Read(buf) - _, err := aead.Open(buf[:0], nil, buf, nil) - assert.NotEqual(t, err, nil) - } -} diff --git a/transport/internet/finalmask/salamander/config.go b/transport/internet/finalmask/salamander/config.go index 8df1285d3..192c466b7 100644 --- a/transport/internet/finalmask/salamander/config.go +++ b/transport/internet/finalmask/salamander/config.go @@ -4,16 +4,24 @@ import ( "net" ) -func (c *Config) UDP() { -} +func (c *Config) UDP() {} + +func (c *Config) HeaderConn() {} func (c *Config) WrapPacketConnClient(raw net.PacketConn, level int, levelCount int) (net.PacketConn, error) { - return NewConnClient(c, raw) + return NewSalamanderConnClient(c, raw) } func (c *Config) WrapPacketConnServer(raw net.PacketConn, level int, levelCount int) (net.PacketConn, error) { - return NewConnServer(c, raw) + return NewSalamanderConnServer(c, raw) } -func (c *Config) HeaderConn() { +func (c *GeckoConfig) UDP() {} + +func (c *GeckoConfig) WrapPacketConnClient(raw net.PacketConn, level int, levelCount int) (net.PacketConn, error) { + return NewGeckoConnClient(c, raw) +} + +func (c *GeckoConfig) WrapPacketConnServer(raw net.PacketConn, level int, levelCount int) (net.PacketConn, error) { + return NewGeckoConnServer(c, raw) } diff --git a/transport/internet/finalmask/salamander/config.pb.go b/transport/internet/finalmask/salamander/config.pb.go index 949df60bd..4f4167be1 100644 --- a/transport/internet/finalmask/salamander/config.pb.go +++ b/transport/internet/finalmask/salamander/config.pb.go @@ -65,13 +65,77 @@ func (x *Config) GetPassword() string { return "" } +type GeckoConfig struct { + state protoimpl.MessageState `protogen:"open.v1"` + Password string `protobuf:"bytes,1,opt,name=password,proto3" json:"password,omitempty"` + MinPacketSize int32 `protobuf:"varint,2,opt,name=MinPacketSize,proto3" json:"MinPacketSize,omitempty"` + MaxPacketSize int32 `protobuf:"varint,3,opt,name=MaxPacketSize,proto3" json:"MaxPacketSize,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *GeckoConfig) Reset() { + *x = GeckoConfig{} + mi := &file_transport_internet_finalmask_salamander_config_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *GeckoConfig) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GeckoConfig) ProtoMessage() {} + +func (x *GeckoConfig) ProtoReflect() protoreflect.Message { + mi := &file_transport_internet_finalmask_salamander_config_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GeckoConfig.ProtoReflect.Descriptor instead. +func (*GeckoConfig) Descriptor() ([]byte, []int) { + return file_transport_internet_finalmask_salamander_config_proto_rawDescGZIP(), []int{1} +} + +func (x *GeckoConfig) GetPassword() string { + if x != nil { + return x.Password + } + return "" +} + +func (x *GeckoConfig) GetMinPacketSize() int32 { + if x != nil { + return x.MinPacketSize + } + return 0 +} + +func (x *GeckoConfig) GetMaxPacketSize() int32 { + if x != nil { + return x.MaxPacketSize + } + return 0 +} + var File_transport_internet_finalmask_salamander_config_proto protoreflect.FileDescriptor const file_transport_internet_finalmask_salamander_config_proto_rawDesc = "" + "\n" + "4transport/internet/finalmask/salamander/config.proto\x12,xray.transport.internet.finalmask.salamander\"$\n" + "\x06Config\x12\x1a\n" + - "\bpassword\x18\x01 \x01(\tR\bpasswordB\xa6\x01\n" + + "\bpassword\x18\x01 \x01(\tR\bpassword\"u\n" + + "\vGeckoConfig\x12\x1a\n" + + "\bpassword\x18\x01 \x01(\tR\bpassword\x12$\n" + + "\rMinPacketSize\x18\x02 \x01(\x05R\rMinPacketSize\x12$\n" + + "\rMaxPacketSize\x18\x03 \x01(\x05R\rMaxPacketSizeB\xa6\x01\n" + "0com.xray.transport.internet.finalmask.salamanderP\x01ZAgithub.com/xtls/xray-core/transport/internet/finalmask/salamander\xaa\x02,Xray.Transport.Internet.Finalmask.Salamanderb\x06proto3" var ( @@ -86,9 +150,10 @@ func file_transport_internet_finalmask_salamander_config_proto_rawDescGZIP() []b return file_transport_internet_finalmask_salamander_config_proto_rawDescData } -var file_transport_internet_finalmask_salamander_config_proto_msgTypes = make([]protoimpl.MessageInfo, 1) +var file_transport_internet_finalmask_salamander_config_proto_msgTypes = make([]protoimpl.MessageInfo, 2) var file_transport_internet_finalmask_salamander_config_proto_goTypes = []any{ - (*Config)(nil), // 0: xray.transport.internet.finalmask.salamander.Config + (*Config)(nil), // 0: xray.transport.internet.finalmask.salamander.Config + (*GeckoConfig)(nil), // 1: xray.transport.internet.finalmask.salamander.GeckoConfig } var file_transport_internet_finalmask_salamander_config_proto_depIdxs = []int32{ 0, // [0:0] is the sub-list for method output_type @@ -109,7 +174,7 @@ func file_transport_internet_finalmask_salamander_config_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_transport_internet_finalmask_salamander_config_proto_rawDesc), len(file_transport_internet_finalmask_salamander_config_proto_rawDesc)), NumEnums: 0, - NumMessages: 1, + NumMessages: 2, NumExtensions: 0, NumServices: 0, }, diff --git a/transport/internet/finalmask/salamander/config.proto b/transport/internet/finalmask/salamander/config.proto index 34bd4cefe..196391035 100644 --- a/transport/internet/finalmask/salamander/config.proto +++ b/transport/internet/finalmask/salamander/config.proto @@ -10,3 +10,8 @@ message Config { string password = 1; } +message GeckoConfig { + string password = 1; + int32 MinPacketSize = 2; + int32 MaxPacketSize = 3; +} \ No newline at end of file diff --git a/transport/internet/finalmask/salamander/conn.go b/transport/internet/finalmask/salamander/conn.go index bfb4934ac..c87150998 100644 --- a/transport/internet/finalmask/salamander/conn.go +++ b/transport/internet/finalmask/salamander/conn.go @@ -1,9 +1,16 @@ package salamander import ( + "crypto/rand" + "encoding/binary" + "errors" "net" + "sync" + "sync/atomic" + "time" - "github.com/xtls/xray-core/common/errors" + "github.com/xtls/xray-core/common/buf" + "github.com/xtls/xray-core/transport/internet/finalmask" ) type salamanderConn struct { @@ -11,22 +18,19 @@ type salamanderConn struct { obfs *SalamanderObfuscator } -func NewConnClient(c *Config, raw net.PacketConn) (net.PacketConn, error) { +func NewSalamanderConnClient(c *Config, raw net.PacketConn) (net.PacketConn, error) { ob, err := NewSalamanderObfuscator([]byte(c.Password)) if err != nil { - return nil, errors.New("salamander err").Base(err) + return nil, err } - - conn := &salamanderConn{ + return &salamanderConn{ PacketConn: raw, obfs: ob, - } - - return conn, nil + }, nil } -func NewConnServer(c *Config, raw net.PacketConn) (net.PacketConn, error) { - return NewConnClient(c, raw) +func NewSalamanderConnServer(c *Config, raw net.PacketConn) (net.PacketConn, error) { + return NewSalamanderConnClient(c, raw) } func (c *salamanderConn) Size() int { @@ -35,12 +39,303 @@ func (c *salamanderConn) Size() int { func (c *salamanderConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { c.obfs.Deobfuscate(p, p[smSaltLen:]) - - return len(p) - smSaltLen, addr, nil + return len(p) - smSaltLen, nil, nil } func (c *salamanderConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { c.obfs.Obfuscate(p[smSaltLen:], p) - return len(p), nil } + +const ( + geckoReassemblyTTL = 8 * time.Second + geckoMaxReassembly = 4096 + geckoMaxPerSource = 8 + + geckoBufferSize = 2048 + + geckoDefaultMinPacket = 512 + geckoDefaultMaxPacket = 1200 +) + +type reassemblyKey struct { + addr string + msgID uint8 +} + +type reassemblyEntry struct { + chunks [][]byte + received int + total uint8 + deadline time.Time +} + +type geckoConn struct { + net.PacketConn + obfs *SalamanderObfuscator + minPkt, maxPkt int + + msgID atomic.Uint32 + + mu sync.Mutex + reassembly map[reassemblyKey]*reassemblyEntry + perSource map[string]int + + closeCh chan struct{} + closeOnce sync.Once +} + +func NewGeckoConnClient(c *GeckoConfig, raw net.PacketConn) (net.PacketConn, error) { + ob, err := NewSalamanderObfuscator([]byte(c.Password)) + if err != nil { + return nil, err + } + minPkt, maxPkt := c.MinPacketSize, c.MaxPacketSize + if minPkt == 0 { + minPkt = geckoDefaultMinPacket + } + if maxPkt == 0 { + maxPkt = geckoDefaultMaxPacket + } + if minPkt <= 0 || minPkt > maxPkt || maxPkt > geckoBufferSize { + return nil, errors.New("gecko: invalid min/max packet size") + } + g := &geckoConn{ + PacketConn: raw, + obfs: ob, + minPkt: int(minPkt), + maxPkt: int(maxPkt), + reassembly: make(map[reassemblyKey]*reassemblyEntry), + perSource: make(map[string]int), + closeCh: make(chan struct{}), + } + go g.gcLoop() + return g, nil +} + +func NewGeckoConnServer(c *GeckoConfig, raw net.PacketConn) (net.PacketConn, error) { + return NewGeckoConnClient(c, raw) +} + +func (c *geckoConn) readObfs(p []byte) (n int, addr net.Addr, err error) { + for { + n, addr, err = c.PacketConn.ReadFrom(p) + if err != nil { + return n, addr, err + } + if n < smSaltLen { + continue + } + c.obfs.Deobfuscate(p[:n], p) + return n - smSaltLen, addr, nil + } +} + +func (c *geckoConn) writeObfs(p []byte, addr net.Addr) (n int, err error) { + b := buf.New() + b.Resize(0, int32(len(p)+smSaltLen)) + defer b.Release() + c.obfs.Obfuscate(p, b.Bytes()) + return c.PacketConn.WriteTo(b.Bytes(), addr) +} + +func (g *geckoConn) WriteTo(p []byte, addr net.Addr) (int, error) { + if len(p) == 0 { + return 0, nil + } + if p[0]&0x80 != 0 { + // QUIC long header, do fragmentation. + return g.writeFragmented(p, addr) + } + // QUIC short header (data), pass through. + return g.writeObfs(p, addr) +} + +func (g *geckoConn) writeFragmented(p []byte, addr net.Addr) (int, error) { + chunks := randomFragmentChunks() + chunkSize := len(p) / chunks + msgID := uint8(g.msgID.Add(1)) + for i := range chunks { + start := i * chunkSize + end := len(p) + if i < chunks-1 { + end = start + chunkSize + } + chunk := p[start:end] + padLen := g.randomPadLen(len(chunk)) + buf := make([]byte, geckoHeaderSize+int(padLen)+len(chunk)) + n, err := encodeFrame(frameHeader{ + padLen: padLen, + msgID: msgID, + chunkIdx: uint8(i), + totalChunks: uint8(chunks), + }, chunk, buf) + if err != nil { + return 0, err + } + if _, err := g.writeObfs(buf[:n], addr); err != nil { + return 0, err + } + } + return len(p), nil +} + +func (g *geckoConn) randomPadLen(chunkLen int) uint16 { + base := smSaltLen + geckoHeaderSize + chunkLen + lo := max(g.minPkt, base) + if lo > g.maxPkt { + return 0 + } + return uint16(lo - base + randIntn(g.maxPkt-lo+1)) +} + +func randomFragmentChunks() int { + return geckoMinFragmentChunks + randIntn(geckoMaxFragmentChunks-geckoMinFragmentChunks+1) +} + +func randIntn(n int) int { + if n <= 1 { + return 0 + } + var b [4]byte + _, _ = rand.Read(b[:]) + return int(binary.BigEndian.Uint32(b[:]) % uint32(n)) +} + +func (g *geckoConn) ReadFrom(p []byte) (int, net.Addr, error) { + b := buf.New() + b.Resize(0, finalmask.UDPSize) + buf := b.Bytes() + defer b.Release() + for { + n, addr, err := g.readObfs(buf) + if err != nil { + return 0, addr, err + } + if n <= 0 { + continue + } + // Top bit set → Gecko fragment frame; clear → short-header packet + // or garbage, passed through for QUIC to handle. + if buf[0]&0x80 == 0 { + return copy(p, buf[:n]), addr, nil + } + h, payload, decErr := decodeFrame(buf[:n]) + if decErr != nil { + // Malformed frame; drop silently. + continue + } + out, ready := g.acceptChunk(addr, h, payload) + if !ready { + continue + } + return copy(p, out), addr, nil + } +} + +func (g *geckoConn) acceptChunk(addr net.Addr, h frameHeader, payload []byte) ([]byte, bool) { + key := reassemblyKey{addr: addr.String(), msgID: h.msgID} + + g.mu.Lock() + defer g.mu.Unlock() + + e, exists := g.reassembly[key] + if !exists { + // Per-source cap. + if g.perSource[key.addr] >= geckoMaxPerSource { + return nil, false + } + // Global cap with eviction. + if len(g.reassembly) >= geckoMaxReassembly { + g.evictOldestLocked() + } + e = &reassemblyEntry{ + chunks: make([][]byte, h.totalChunks), + total: h.totalChunks, + deadline: time.Now().Add(geckoReassemblyTTL), + } + g.reassembly[key] = e + g.perSource[key.addr]++ + } else if e.total != h.totalChunks { + // Inconsistent chunk count; drop. + return nil, false + } + if int(h.chunkIdx) >= len(e.chunks) || e.chunks[h.chunkIdx] != nil { + // Bad index or duplicate; drop. + return nil, false + } + cp := make([]byte, len(payload)) + copy(cp, payload) + e.chunks[h.chunkIdx] = cp + e.received++ + if e.received < int(e.total) { + return nil, false + } + + total := 0 + for _, c := range e.chunks { + total += len(c) + } + out := make([]byte, total) + off := 0 + for _, c := range e.chunks { + off += copy(out[off:], c) + } + g.dropEntryLocked(key) + return out, true +} + +func (g *geckoConn) gcLoop() { + t := time.NewTicker(geckoReassemblyTTL / 2) + defer t.Stop() + for { + select { + case <-g.closeCh: + return + case now := <-t.C: + g.gcExpired(now) + } + } +} + +func (g *geckoConn) gcExpired(now time.Time) { + g.mu.Lock() + defer g.mu.Unlock() + for k, e := range g.reassembly { + if now.After(e.deadline) { + g.dropEntryLocked(k) + } + } +} + +func (g *geckoConn) dropEntryLocked(k reassemblyKey) { + if _, ok := g.reassembly[k]; !ok { + return + } + delete(g.reassembly, k) + g.perSource[k.addr]-- + if g.perSource[k.addr] <= 0 { + delete(g.perSource, k.addr) + } +} + +func (g *geckoConn) evictOldestLocked() { + var oldestKey reassemblyKey + var oldestDeadline time.Time + first := true + for k, e := range g.reassembly { + if first || e.deadline.Before(oldestDeadline) { + oldestKey = k + oldestDeadline = e.deadline + first = false + } + } + if !first { + g.dropEntryLocked(oldestKey) + } +} + +func (g *geckoConn) Close() error { + g.closeOnce.Do(func() { close(g.closeCh) }) + return g.PacketConn.Close() +} diff --git a/transport/internet/finalmask/salamander/gecko.go b/transport/internet/finalmask/salamander/gecko.go new file mode 100644 index 000000000..e4db7049c --- /dev/null +++ b/transport/internet/finalmask/salamander/gecko.go @@ -0,0 +1,86 @@ +package salamander + +import ( + "crypto/rand" + "encoding/binary" + "errors" +) + +const ( + geckoFlagFragment = 0x80 + geckoHeaderSize = 5 + + geckoMinFragmentChunks = 2 + geckoMaxFragmentChunks = 8 +) + +var ( + errFrameTruncated = errors.New("gecko frame truncated") + errFrameInvalid = errors.New("gecko frame invalid") +) + +// frameHeader is a Gecko fragment frame header. +// Wire layout (after Salamander decryption): +// +// byte 0: 0x80 (fragment marker; low 7 bits reserved) +// byte 1: msgID +// byte 2: chunkIdx:4 | totalChunks:4 +// byte 3-4: padLen (uint16, big-endian) +// then padLen random padding bytes, then the chunk payload +type frameHeader struct { + padLen uint16 + msgID uint8 + chunkIdx uint8 // < totalChunks + totalChunks uint8 // [2, 8] +} + +// encodeFrame writes a frame into out, filling the padding region with random +// bytes. out must be at least geckoHeaderSize + h.padLen + len(payload) long. +func encodeFrame(h frameHeader, payload, out []byte) (int, error) { + if h.totalChunks < geckoMinFragmentChunks || h.totalChunks > geckoMaxFragmentChunks { + return 0, errFrameInvalid + } + if h.chunkIdx >= h.totalChunks { + return 0, errFrameInvalid + } + needed := geckoHeaderSize + int(h.padLen) + len(payload) + if len(out) < needed { + return 0, errFrameTruncated + } + out[0] = geckoFlagFragment + out[1] = h.msgID + out[2] = h.chunkIdx<<4 | h.totalChunks&0x0f + binary.BigEndian.PutUint16(out[3:5], h.padLen) + if _, err := rand.Read(out[geckoHeaderSize : geckoHeaderSize+int(h.padLen)]); err != nil { + return 0, err + } + copy(out[geckoHeaderSize+int(h.padLen):], payload) + return needed, nil +} + +// decodeFrame parses a frame from in. The returned payload is a sub-slice of +// in (zero-copy) covering the bytes after the header and padding. +func decodeFrame(in []byte) (frameHeader, []byte, error) { + if len(in) < geckoHeaderSize { + return frameHeader{}, nil, errFrameTruncated + } + if in[0]&geckoFlagFragment == 0 { + return frameHeader{}, nil, errFrameInvalid + } + h := frameHeader{ + msgID: in[1], + chunkIdx: in[2] >> 4, + totalChunks: in[2] & 0x0f, + padLen: binary.BigEndian.Uint16(in[3:5]), + } + if h.totalChunks < geckoMinFragmentChunks || h.totalChunks > geckoMaxFragmentChunks { + return frameHeader{}, nil, errFrameInvalid + } + if h.chunkIdx >= h.totalChunks { + return frameHeader{}, nil, errFrameInvalid + } + if geckoHeaderSize+int(h.padLen) > len(in) { + return frameHeader{}, nil, errFrameTruncated + } + return h, in[geckoHeaderSize+int(h.padLen):], nil +} diff --git a/transport/internet/finalmask/salamander/salamander.go b/transport/internet/finalmask/salamander/salamander.go index 86d92dcdf..d81863aa9 100644 --- a/transport/internet/finalmask/salamander/salamander.go +++ b/transport/internet/finalmask/salamander/salamander.go @@ -24,16 +24,21 @@ type SalamanderObfuscator struct { PSK []byte RandSrc *rand.Rand - lk sync.Mutex + lk sync.Mutex + keyInput []byte } func NewSalamanderObfuscator(psk []byte) (*SalamanderObfuscator, error) { if len(psk) < smPSKMinLen { return nil, ErrPSKTooShort } + pskCopy := append([]byte(nil), psk...) + keyInput := make([]byte, len(pskCopy)+smSaltLen) + copy(keyInput, pskCopy) return &SalamanderObfuscator{ - PSK: psk, - RandSrc: rand.New(rand.NewSource(time.Now().UnixNano())), + PSK: pskCopy, + RandSrc: rand.New(rand.NewSource(time.Now().UnixNano())), + keyInput: keyInput, }, nil } @@ -44,8 +49,8 @@ func (o *SalamanderObfuscator) Obfuscate(in, out []byte) int { } o.lk.Lock() _, _ = o.RandSrc.Read(out[:smSaltLen]) + key := o.keyLocked(out[:smSaltLen]) o.lk.Unlock() - key := o.key(out[:smSaltLen]) for i, c := range in { out[i+smSaltLen] = c ^ key[i%smKeyLen] } @@ -57,13 +62,16 @@ func (o *SalamanderObfuscator) Deobfuscate(in, out []byte) int { if outLen <= 0 || len(out) < outLen { return 0 } - key := o.key(in[:smSaltLen]) + o.lk.Lock() + key := o.keyLocked(in[:smSaltLen]) + o.lk.Unlock() for i, c := range in[smSaltLen:] { out[i] = c ^ key[i%smKeyLen] } return outLen } -func (o *SalamanderObfuscator) key(salt []byte) [smKeyLen]byte { - return blake2b.Sum256(append(o.PSK, salt...)) +func (o *SalamanderObfuscator) keyLocked(salt []byte) [smKeyLen]byte { + copy(o.keyInput[len(o.PSK):], salt[:smSaltLen]) + return blake2b.Sum256(o.keyInput) } diff --git a/transport/internet/finalmask/salamander/salamander_test.go b/transport/internet/finalmask/salamander/salamander_test.go deleted file mode 100644 index ffd508218..000000000 --- a/transport/internet/finalmask/salamander/salamander_test.go +++ /dev/null @@ -1,81 +0,0 @@ -package salamander_test - -import ( - "crypto/rand" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/xtls/xray-core/transport/internet/finalmask/salamander" -) - -const ( - smSaltLen = 8 -) - -func BenchmarkSalamanderObfuscator_Obfuscate(b *testing.B) { - o, _ := salamander.NewSalamanderObfuscator([]byte("average_password")) - in := make([]byte, 1200) - _, _ = rand.Read(in) - out := make([]byte, 2048) - b.ResetTimer() - for i := 0; i < b.N; i++ { - o.Obfuscate(in, out) - } -} - -func BenchmarkSalamanderObfuscator_Deobfuscate(b *testing.B) { - o, _ := salamander.NewSalamanderObfuscator([]byte("average_password")) - in := make([]byte, 1200) - _, _ = rand.Read(in) - out := make([]byte, 2048) - b.ResetTimer() - for i := 0; i < b.N; i++ { - o.Deobfuscate(in, out) - } -} - -func TestSalamanderObfuscator(t *testing.T) { - o, _ := salamander.NewSalamanderObfuscator([]byte("average_password")) - in := make([]byte, 1200) - oOut := make([]byte, 2048) - dOut := make([]byte, 2048) - for i := 0; i < 1000; i++ { - _, _ = rand.Read(in) - n := o.Obfuscate(in, oOut) - assert.Equal(t, len(in)+smSaltLen, n) - n = o.Deobfuscate(oOut[:n], dOut) - assert.Equal(t, len(in), n) - assert.Equal(t, in, dOut[:n]) - } -} - -func TestSalamanderInPlace(t *testing.T) { - o, _ := salamander.NewSalamanderObfuscator([]byte("average_password")) - - in := make([]byte, 1200) - out := make([]byte, 2048) - _, _ = rand.Read(in) - o.Obfuscate(in, out) - - out2 := make([]byte, 2048) - copy(out2[smSaltLen:], in) - o.Obfuscate(out2[smSaltLen:], out2) - - dOut := make([]byte, 2048) - o.Deobfuscate(out, dOut) - - o.Deobfuscate(out2, out2) - - assert.Equal(t, in, dOut[:1200]) - assert.Equal(t, in, out2[:1200]) -} - -func TestSalamanderBounce(t *testing.T) { - o, _ := salamander.NewSalamanderObfuscator([]byte("average_password")) - buf := make([]byte, 8) - for i := 0; i < 1000; i++ { - _, _ = rand.Read(buf) - n := o.Deobfuscate(buf, buf) - assert.Equal(t, 0, n) - } -} diff --git a/transport/internet/finalmask/udp_test.go b/transport/internet/finalmask/udp_test.go index 0c0641323..8759f9fbc 100644 --- a/transport/internet/finalmask/udp_test.go +++ b/transport/internet/finalmask/udp_test.go @@ -462,50 +462,6 @@ func TestUDPcustomStaticHeaderWireShape(t *testing.T) { } } -func TestUDPcustomServerRejectsMismatchedStaticHeader(t *testing.T) { - cfg := &custom.UDPConfig{ - Client: []*custom.UDPItem{ - {Packet: []byte{0x01, 0x02}}, - }, - Server: []*custom.UDPItem{ - {Packet: []byte{0x03}}, - }, - } - maskManager := finalmask.NewUdpmaskManager([]finalmask.Udpmask{cfg}) - - clientRaw, err := net.ListenPacket("udp", "127.0.0.1:0") - if err != nil { - t.Fatal(err) - } - defer clientRaw.Close() - - serverRaw, err := net.ListenPacket("udp", "127.0.0.1:0") - if err != nil { - t.Fatal(err) - } - defer serverRaw.Close() - - server, err := maskManager.WrapPacketConnServer(serverRaw) - if err != nil { - t.Fatal(err) - } - - _ = server.SetDeadline(time.Now().Add(200 * time.Millisecond)) - - if _, err := clientRaw.WriteTo([]byte{0x09, 0x09, 'b', 'a', 'd'}, server.LocalAddr()); err != nil { - t.Fatal(err) - } - - buf := make([]byte, 128) - n, _, err := server.ReadFrom(buf) - if n != 0 { - t.Fatalf("expected no payload on mismatched header, got %d bytes", n) - } - if err != nil { - t.Fatalf("expected mismatch to be dropped without surfaced error, got %v", err) - } -} - func TestUDPcustomStandaloneClientSendsDetachedHandshakeBeforePayload(t *testing.T) { _, serverRaw, client, _ := newUDPClientServerPair(t, newStandaloneEchoUDPConfig())