WireGuard inbound: Support dynamic peer management (#6360)

https://github.com/XTLS/Xray-core/pull/6360#issuecomment-4780311547

Closes https://github.com/XTLS/Xray-core/issues/6314

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-authored-by: LjhAUMEM <llnu14702@gmail.com>
This commit is contained in:
bitwiresys
2026-06-27 14:41:22 +03:00
committed by GitHub
parent f496437b84
commit 345c76f9a8
14 changed files with 280 additions and 114 deletions
+4 -13
View File
@@ -24,8 +24,7 @@ const (
type ClientConfig struct {
state protoimpl.MessageState `protogen:"open.v1"`
Version int32 `protobuf:"varint,1,opt,name=version,proto3" json:"version,omitempty"`
Server *protocol.ServerEndpoint `protobuf:"bytes,2,opt,name=server,proto3" json:"server,omitempty"`
Server *protocol.ServerEndpoint `protobuf:"bytes,1,opt,name=server,proto3" json:"server,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -60,13 +59,6 @@ func (*ClientConfig) Descriptor() ([]byte, []int) {
return file_proxy_hysteria_config_proto_rawDescGZIP(), []int{0}
}
func (x *ClientConfig) GetVersion() int32 {
if x != nil {
return x.Version
}
return 0
}
func (x *ClientConfig) GetServer() *protocol.ServerEndpoint {
if x != nil {
return x.Server
@@ -122,10 +114,9 @@ var File_proxy_hysteria_config_proto protoreflect.FileDescriptor
const file_proxy_hysteria_config_proto_rawDesc = "" +
"\n" +
"\x1bproxy/hysteria/config.proto\x12\x13xray.proxy.hysteria\x1a!common/protocol/server_spec.proto\x1a\x1acommon/protocol/user.proto\"f\n" +
"\fClientConfig\x12\x18\n" +
"\aversion\x18\x01 \x01(\x05R\aversion\x12<\n" +
"\x06server\x18\x02 \x01(\v2$.xray.common.protocol.ServerEndpointR\x06server\"@\n" +
"\x1bproxy/hysteria/config.proto\x12\x13xray.proxy.hysteria\x1a!common/protocol/server_spec.proto\x1a\x1acommon/protocol/user.proto\"L\n" +
"\fClientConfig\x12<\n" +
"\x06server\x18\x01 \x01(\v2$.xray.common.protocol.ServerEndpointR\x06server\"@\n" +
"\fServerConfig\x120\n" +
"\x05users\x18\x01 \x03(\v2\x1a.xray.common.protocol.UserR\x05usersB[\n" +
"\x17com.xray.proxy.hysteriaP\x01Z(github.com/xtls/xray-core/proxy/hysteria\xaa\x02\x13Xray.Proxy.Hysteriab\x06proto3"
+1 -2
View File
@@ -10,8 +10,7 @@ import "common/protocol/server_spec.proto";
import "common/protocol/user.proto";
message ClientConfig {
int32 version = 1;
xray.common.protocol.ServerEndpoint server = 2;
xray.common.protocol.ServerEndpoint server = 1;
}
message ServerConfig {
+59
View File
@@ -1 +1,60 @@
package wireguard
import (
"encoding/hex"
"net/netip"
"github.com/xtls/xray-core/common/protocol"
"google.golang.org/protobuf/proto"
)
func (p *PeerConfig) AsAccount() (protocol.Account, error) {
pub, err := ParseKey(p.PublicKey)
if err != nil {
return nil, err
}
allowedIPs := make([]netip.Prefix, 0, len(p.AllowedIps))
for i := range p.AllowedIps {
p, err := netip.ParsePrefix(p.AllowedIps[i])
if err != nil {
return nil, err
}
allowedIPs = append(allowedIPs, p)
}
return &MemoryAccount{
Pub: *pub,
AllowedIPs: allowedIPs,
PreSharedKey: p.PreSharedKey,
KeepAlive: p.KeepAlive,
}, nil
}
type MemoryAccount struct {
Pub [32]byte
AllowedIPs []netip.Prefix
PreSharedKey string
KeepAlive string
}
func (a *MemoryAccount) Equals(other protocol.Account) bool {
if b, ok := other.(*MemoryAccount); ok {
return a.Pub == b.Pub
}
return false
}
func (a *MemoryAccount) ToProto() proto.Message {
allowedIPs := make([]string, 0, len(a.AllowedIPs))
for i := range a.AllowedIPs {
allowedIPs = append(allowedIPs, a.AllowedIPs[i].String())
}
return &PeerConfig{
PublicKey: hex.EncodeToString(a.Pub[:]),
AllowedIps: allowedIPs,
PreSharedKey: a.PreSharedKey,
KeepAlive: a.KeepAlive,
}
}
+21 -9
View File
@@ -7,6 +7,7 @@
package wireguard
import (
protocol "github.com/xtls/xray-core/common/protocol"
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
@@ -157,6 +158,7 @@ type DeviceConfig struct {
SecretKey string `protobuf:"bytes,1,opt,name=secret_key,json=secretKey,proto3" json:"secret_key,omitempty"`
Endpoint []string `protobuf:"bytes,2,rep,name=endpoint,proto3" json:"endpoint,omitempty"`
Peers []*PeerConfig `protobuf:"bytes,3,rep,name=peers,proto3" json:"peers,omitempty"`
Users []*protocol.User `protobuf:"bytes,5,rep,name=users,proto3" json:"users,omitempty"`
Mtu int32 `protobuf:"varint,4,opt,name=mtu,proto3" json:"mtu,omitempty"`
Reserved []byte `protobuf:"bytes,6,opt,name=reserved,proto3" json:"reserved,omitempty"`
DomainStrategy DeviceConfig_DomainStrategy `protobuf:"varint,7,opt,name=domain_strategy,json=domainStrategy,proto3,enum=xray.proxy.wireguard.DeviceConfig_DomainStrategy" json:"domain_strategy,omitempty"`
@@ -217,6 +219,13 @@ func (x *DeviceConfig) GetPeers() []*PeerConfig {
return nil
}
func (x *DeviceConfig) GetUsers() []*protocol.User {
if x != nil {
return x.Users
}
return nil
}
func (x *DeviceConfig) GetMtu() int32 {
if x != nil {
return x.Mtu
@@ -256,7 +265,7 @@ var File_proxy_wireguard_config_proto protoreflect.FileDescriptor
const file_proxy_wireguard_config_proto_rawDesc = "" +
"\n" +
"\x1cproxy/wireguard/config.proto\x12\x14xray.proxy.wireguard\"\xad\x01\n" +
"\x1cproxy/wireguard/config.proto\x12\x14xray.proxy.wireguard\x1a\x1acommon/protocol/user.proto\"\xad\x01\n" +
"\n" +
"PeerConfig\x12\x1d\n" +
"\n" +
@@ -266,12 +275,13 @@ const file_proxy_wireguard_config_proto_rawDesc = "" +
"\n" +
"keep_alive\x18\x04 \x01(\tR\tkeepAlive\x12\x1f\n" +
"\vallowed_ips\x18\x05 \x03(\tR\n" +
"allowedIps\"\xaa\x03\n" +
"allowedIps\"\xdc\x03\n" +
"\fDeviceConfig\x12\x1d\n" +
"\n" +
"secret_key\x18\x01 \x01(\tR\tsecretKey\x12\x1a\n" +
"\bendpoint\x18\x02 \x03(\tR\bendpoint\x126\n" +
"\x05peers\x18\x03 \x03(\v2 .xray.proxy.wireguard.PeerConfigR\x05peers\x12\x10\n" +
"\x05peers\x18\x03 \x03(\v2 .xray.proxy.wireguard.PeerConfigR\x05peers\x120\n" +
"\x05users\x18\x05 \x03(\v2\x1a.xray.common.protocol.UserR\x05users\x12\x10\n" +
"\x03mtu\x18\x04 \x01(\x05R\x03mtu\x12\x1a\n" +
"\breserved\x18\x06 \x01(\fR\breserved\x12Z\n" +
"\x0fdomain_strategy\x18\a \x01(\x0e21.xray.proxy.wireguard.DeviceConfig.DomainStrategyR\x0edomainStrategy\x12\x1b\n" +
@@ -305,15 +315,17 @@ var file_proxy_wireguard_config_proto_goTypes = []any{
(DeviceConfig_DomainStrategy)(0), // 0: xray.proxy.wireguard.DeviceConfig.DomainStrategy
(*PeerConfig)(nil), // 1: xray.proxy.wireguard.PeerConfig
(*DeviceConfig)(nil), // 2: xray.proxy.wireguard.DeviceConfig
(*protocol.User)(nil), // 3: xray.common.protocol.User
}
var file_proxy_wireguard_config_proto_depIdxs = []int32{
1, // 0: xray.proxy.wireguard.DeviceConfig.peers:type_name -> xray.proxy.wireguard.PeerConfig
0, // 1: xray.proxy.wireguard.DeviceConfig.domain_strategy:type_name -> xray.proxy.wireguard.DeviceConfig.DomainStrategy
2, // [2:2] is the sub-list for method output_type
2, // [2:2] is the sub-list for method input_type
2, // [2:2] is the sub-list for extension type_name
2, // [2:2] is the sub-list for extension extendee
0, // [0:2] is the sub-list for field type_name
3, // 1: xray.proxy.wireguard.DeviceConfig.users:type_name -> xray.common.protocol.User
0, // 2: xray.proxy.wireguard.DeviceConfig.domain_strategy:type_name -> xray.proxy.wireguard.DeviceConfig.DomainStrategy
3, // [3:3] is the sub-list for method output_type
3, // [3:3] is the sub-list for method input_type
3, // [3:3] is the sub-list for extension type_name
3, // [3:3] is the sub-list for extension extendee
0, // [0:3] is the sub-list for field type_name
}
func init() { file_proxy_wireguard_config_proto_init() }
+3
View File
@@ -6,6 +6,8 @@ option go_package = "github.com/xtls/xray-core/proxy/wireguard";
option java_package = "com.xray.proxy.wireguard";
option java_multiple_files = true;
import "common/protocol/user.proto";
message PeerConfig {
string public_key = 1;
string pre_shared_key = 2;
@@ -25,6 +27,7 @@ message DeviceConfig {
string secret_key = 1;
repeated string endpoint = 2;
repeated PeerConfig peers = 3;
repeated xray.common.protocol.User users = 5;
int32 mtu = 4;
bytes reserved = 6;
+156 -16
View File
@@ -2,16 +2,20 @@ package wireguard
import (
"context"
"encoding/hex"
"fmt"
"net/netip"
"reflect"
"strings"
"sync"
"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/buf"
c "github.com/xtls/xray-core/common/ctx"
"github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/common/log"
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/protocol"
"github.com/xtls/xray-core/common/session"
"github.com/xtls/xray-core/core"
"github.com/xtls/xray-core/features/policy"
@@ -20,6 +24,7 @@ import (
"github.com/xtls/xray-core/transport"
"github.com/xtls/xray-core/transport/internet"
"github.com/xtls/xray-core/transport/internet/stat"
"golang.org/x/crypto/curve25519"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -42,6 +47,9 @@ type Server struct {
stack *stack.Stack
dev *device.Device
mu sync.Mutex
pub [32]byte
users *sync.Map
}
func NewServer(ctx context.Context, conf *DeviceConfig) (*Server, error) {
@@ -72,15 +80,6 @@ func NewServer(ctx context.Context, conf *DeviceConfig) (*Server, error) {
}
}
if len(conf.Peers) == 0 {
return nil, errors.New("empty peers")
}
for _, peer := range conf.Peers {
if peer.PublicKey == "" {
return nil, errors.New("peer without publickey")
}
}
localAddresses := make([]netip.Addr, 0, len(conf.Endpoint))
for _, localaddress := range conf.Endpoint {
addr, err := netip.ParseAddr(localaddress)
@@ -101,6 +100,19 @@ func NewServer(ctx context.Context, conf *DeviceConfig) (*Server, error) {
return nil, err
}
pri := common.Must2(ParseKey(conf.SecretKey))
var pub [32]byte
curve25519.ScalarBaseMult(&pub, pri)
users := &sync.Map{}
for _, u := range conf.Users {
user, err := u.ToMemoryUser()
if err != nil {
return nil, err
}
users.Store(user.Account.(*MemoryAccount).Pub, user)
}
return &Server{
conf: conf,
ctx: core.ToBackgroundDetachedContext(ctx),
@@ -116,9 +128,100 @@ func NewServer(ctx context.Context, conf *DeviceConfig) (*Server, error) {
tun: tun,
stack: stack,
pub: pub,
users: users,
}, nil
}
func (s *Server) AddUser(ctx context.Context, user *protocol.MemoryUser) error {
s.mu.Lock()
defer s.mu.Unlock()
if s.dev == nil {
return errors.New("too early")
}
peer := user.Account.(*MemoryAccount)
if peer.Pub == s.pub {
return errors.New("invalid public key")
}
var sb strings.Builder
sb.WriteString("public_key=" + hex.EncodeToString(peer.Pub[:]) + "\n")
sb.WriteString("replace_allowed_ips=true\n")
for i := range peer.AllowedIPs {
sb.WriteString("allowed_ip=" + peer.AllowedIPs[i].String() + "\n")
}
if peer.PreSharedKey != "" {
sb.WriteString("preshared_key=" + peer.PreSharedKey + "\n")
}
if peer.KeepAlive != "" {
sb.WriteString("persistent_keepalive_interval=" + peer.KeepAlive + "\n")
}
err := s.dev.IpcSet(sb.String())
if err != nil {
return err
}
s.users.Store(peer.Pub, user)
return nil
}
func (s *Server) RemoveUser(ctx context.Context, email string) error {
s.mu.Lock()
defer s.mu.Unlock()
if s.dev == nil {
return errors.New("too early")
}
if user := s.GetUser(ctx, email); user != nil {
peer := user.Account.(*MemoryAccount)
err := s.dev.IpcSet("public_key=" + hex.EncodeToString(peer.Pub[:]) + "\nremove=true\n")
if err != nil {
return err
}
s.users.Delete(peer.Pub)
}
return nil
}
func (s *Server) GetUser(ctx context.Context, email string) (user *protocol.MemoryUser) {
s.users.Range(func(key, value any) bool {
if value.(*protocol.MemoryUser).Email == email {
user = value.(*protocol.MemoryUser)
return false
}
return true
})
return
}
func (s *Server) GetUserByAddr(ctx context.Context, addr netip.Addr) (user *protocol.MemoryUser) {
s.users.Range(func(key, value any) bool {
peer := value.(*protocol.MemoryUser).Account.(*MemoryAccount)
for i := range peer.AllowedIPs {
if peer.AllowedIPs[i].Contains(addr) {
user = value.(*protocol.MemoryUser)
return false
}
}
return true
})
return
}
func (s *Server) GetUsers(ctx context.Context) (users []*protocol.MemoryUser) {
s.users.Range(func(key, value interface{}) bool {
users = append(users, value.(*protocol.MemoryUser))
return true
})
return
}
func (s *Server) GetUsersCount(context.Context) (count int64) {
s.users.Range(func(key, value interface{}) bool {
count++
return true
})
return
}
// Network implements proxy.Inbound.Network.
func (*Server) Network() []net.Network {
return []net.Network{}
@@ -196,18 +299,20 @@ func (s *Server) Start() error {
dev := device.NewDevice(s.tun, bind, logger)
var cfg strings.Builder
cfg.WriteString("private_key=" + s.conf.SecretKey + "\n")
for _, peer := range s.conf.Peers {
cfg.WriteString("public_key=" + peer.PublicKey + "\n")
s.users.Range(func(key, value any) bool {
peer := value.(*protocol.MemoryUser).Account.(*MemoryAccount)
cfg.WriteString("public_key=" + hex.EncodeToString(peer.Pub[:]) + "\n")
for i := range peer.AllowedIPs {
cfg.WriteString("allowed_ip=" + peer.AllowedIPs[i].String() + "\n")
}
if peer.PreSharedKey != "" {
cfg.WriteString("preshared_key=" + peer.PreSharedKey + "\n")
}
for _, ip := range peer.AllowedIps {
cfg.WriteString("allowed_ip=" + ip + "\n")
}
if peer.KeepAlive != "" {
cfg.WriteString("persistent_keepalive_interval=" + peer.KeepAlive + "\n")
}
}
return true
})
err := dev.IpcSet(cfg.String())
if err != nil {
return err
@@ -227,12 +332,36 @@ func (s *Server) HandleConnection(conn net.Conn, dest net.Destination) {
defer cancel()
ctx = c.ContextWithID(ctx, session.NewID())
source := net.DestinationFromAddr(conn.RemoteAddr())
remote := conn.RemoteAddr()
if remote == nil {
errors.LogError(context.Background(), "nil remote")
return
}
var addr netip.Addr
switch v := remote.(type) {
case *net.TCPAddr:
addr, _ = netip.AddrFromSlice(v.IP)
case *net.UDPAddr:
addr, _ = netip.AddrFromSlice(v.IP)
default:
errors.LogError(context.Background(), "invalid addr type ", reflect.TypeOf(v))
return
}
user := s.GetUserByAddr(context.TODO(), addr)
if user == nil {
errors.LogError(context.Background(), "nil user form ", remote, " to ", dest)
return
}
source := net.DestinationFromAddr(remote)
inbound := session.Inbound{
Name: "wireguard",
Tag: s.tag,
CanSpliceCopy: 3,
Source: source,
User: user,
}
ctx = session.ContextWithInbound(ctx, &inbound)
@@ -257,3 +386,14 @@ func (s *Server) HandleConnection(conn net.Conn, dest net.Destination) {
errors.LogError(ctx, errors.New("connection closed").Base(err))
}
}
func ParseKey(str string) (*[32]byte, error) {
slice, err := hex.DecodeString(str)
if err != nil {
return nil, err
}
if len(slice) != 32 {
return nil, errors.New("len(slice) != 32")
}
return (*[32]byte)(slice), nil
}