mirror of
https://github.com/XTLS/Xray-core.git
synced 2026-06-29 16:33:05 +00:00
WireGuard inbound: Support dynamic peer management (#6360)
https://github.com/XTLS/Xray-core/pull/6360#issuecomment-4780311547 Closes https://github.com/XTLS/Xray-core/issues/6314 --------- Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com> Co-authored-by: LjhAUMEM <llnu14702@gmail.com>
This commit is contained in:
@@ -24,8 +24,7 @@ const (
|
||||
|
||||
type ClientConfig struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
Version int32 `protobuf:"varint,1,opt,name=version,proto3" json:"version,omitempty"`
|
||||
Server *protocol.ServerEndpoint `protobuf:"bytes,2,opt,name=server,proto3" json:"server,omitempty"`
|
||||
Server *protocol.ServerEndpoint `protobuf:"bytes,1,opt,name=server,proto3" json:"server,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
@@ -60,13 +59,6 @@ func (*ClientConfig) Descriptor() ([]byte, []int) {
|
||||
return file_proxy_hysteria_config_proto_rawDescGZIP(), []int{0}
|
||||
}
|
||||
|
||||
func (x *ClientConfig) GetVersion() int32 {
|
||||
if x != nil {
|
||||
return x.Version
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (x *ClientConfig) GetServer() *protocol.ServerEndpoint {
|
||||
if x != nil {
|
||||
return x.Server
|
||||
@@ -122,10 +114,9 @@ var File_proxy_hysteria_config_proto protoreflect.FileDescriptor
|
||||
|
||||
const file_proxy_hysteria_config_proto_rawDesc = "" +
|
||||
"\n" +
|
||||
"\x1bproxy/hysteria/config.proto\x12\x13xray.proxy.hysteria\x1a!common/protocol/server_spec.proto\x1a\x1acommon/protocol/user.proto\"f\n" +
|
||||
"\fClientConfig\x12\x18\n" +
|
||||
"\aversion\x18\x01 \x01(\x05R\aversion\x12<\n" +
|
||||
"\x06server\x18\x02 \x01(\v2$.xray.common.protocol.ServerEndpointR\x06server\"@\n" +
|
||||
"\x1bproxy/hysteria/config.proto\x12\x13xray.proxy.hysteria\x1a!common/protocol/server_spec.proto\x1a\x1acommon/protocol/user.proto\"L\n" +
|
||||
"\fClientConfig\x12<\n" +
|
||||
"\x06server\x18\x01 \x01(\v2$.xray.common.protocol.ServerEndpointR\x06server\"@\n" +
|
||||
"\fServerConfig\x120\n" +
|
||||
"\x05users\x18\x01 \x03(\v2\x1a.xray.common.protocol.UserR\x05usersB[\n" +
|
||||
"\x17com.xray.proxy.hysteriaP\x01Z(github.com/xtls/xray-core/proxy/hysteria\xaa\x02\x13Xray.Proxy.Hysteriab\x06proto3"
|
||||
|
||||
@@ -10,8 +10,7 @@ import "common/protocol/server_spec.proto";
|
||||
import "common/protocol/user.proto";
|
||||
|
||||
message ClientConfig {
|
||||
int32 version = 1;
|
||||
xray.common.protocol.ServerEndpoint server = 2;
|
||||
xray.common.protocol.ServerEndpoint server = 1;
|
||||
}
|
||||
|
||||
message ServerConfig {
|
||||
|
||||
@@ -1 +1,60 @@
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"net/netip"
|
||||
|
||||
"github.com/xtls/xray-core/common/protocol"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
func (p *PeerConfig) AsAccount() (protocol.Account, error) {
|
||||
pub, err := ParseKey(p.PublicKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
allowedIPs := make([]netip.Prefix, 0, len(p.AllowedIps))
|
||||
for i := range p.AllowedIps {
|
||||
p, err := netip.ParsePrefix(p.AllowedIps[i])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
allowedIPs = append(allowedIPs, p)
|
||||
}
|
||||
|
||||
return &MemoryAccount{
|
||||
Pub: *pub,
|
||||
AllowedIPs: allowedIPs,
|
||||
PreSharedKey: p.PreSharedKey,
|
||||
KeepAlive: p.KeepAlive,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type MemoryAccount struct {
|
||||
Pub [32]byte
|
||||
AllowedIPs []netip.Prefix
|
||||
PreSharedKey string
|
||||
KeepAlive string
|
||||
}
|
||||
|
||||
func (a *MemoryAccount) Equals(other protocol.Account) bool {
|
||||
if b, ok := other.(*MemoryAccount); ok {
|
||||
return a.Pub == b.Pub
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (a *MemoryAccount) ToProto() proto.Message {
|
||||
allowedIPs := make([]string, 0, len(a.AllowedIPs))
|
||||
for i := range a.AllowedIPs {
|
||||
allowedIPs = append(allowedIPs, a.AllowedIPs[i].String())
|
||||
}
|
||||
|
||||
return &PeerConfig{
|
||||
PublicKey: hex.EncodeToString(a.Pub[:]),
|
||||
AllowedIps: allowedIPs,
|
||||
PreSharedKey: a.PreSharedKey,
|
||||
KeepAlive: a.KeepAlive,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
protocol "github.com/xtls/xray-core/common/protocol"
|
||||
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
|
||||
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
|
||||
reflect "reflect"
|
||||
@@ -157,6 +158,7 @@ type DeviceConfig struct {
|
||||
SecretKey string `protobuf:"bytes,1,opt,name=secret_key,json=secretKey,proto3" json:"secret_key,omitempty"`
|
||||
Endpoint []string `protobuf:"bytes,2,rep,name=endpoint,proto3" json:"endpoint,omitempty"`
|
||||
Peers []*PeerConfig `protobuf:"bytes,3,rep,name=peers,proto3" json:"peers,omitempty"`
|
||||
Users []*protocol.User `protobuf:"bytes,5,rep,name=users,proto3" json:"users,omitempty"`
|
||||
Mtu int32 `protobuf:"varint,4,opt,name=mtu,proto3" json:"mtu,omitempty"`
|
||||
Reserved []byte `protobuf:"bytes,6,opt,name=reserved,proto3" json:"reserved,omitempty"`
|
||||
DomainStrategy DeviceConfig_DomainStrategy `protobuf:"varint,7,opt,name=domain_strategy,json=domainStrategy,proto3,enum=xray.proxy.wireguard.DeviceConfig_DomainStrategy" json:"domain_strategy,omitempty"`
|
||||
@@ -217,6 +219,13 @@ func (x *DeviceConfig) GetPeers() []*PeerConfig {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *DeviceConfig) GetUsers() []*protocol.User {
|
||||
if x != nil {
|
||||
return x.Users
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *DeviceConfig) GetMtu() int32 {
|
||||
if x != nil {
|
||||
return x.Mtu
|
||||
@@ -256,7 +265,7 @@ var File_proxy_wireguard_config_proto protoreflect.FileDescriptor
|
||||
|
||||
const file_proxy_wireguard_config_proto_rawDesc = "" +
|
||||
"\n" +
|
||||
"\x1cproxy/wireguard/config.proto\x12\x14xray.proxy.wireguard\"\xad\x01\n" +
|
||||
"\x1cproxy/wireguard/config.proto\x12\x14xray.proxy.wireguard\x1a\x1acommon/protocol/user.proto\"\xad\x01\n" +
|
||||
"\n" +
|
||||
"PeerConfig\x12\x1d\n" +
|
||||
"\n" +
|
||||
@@ -266,12 +275,13 @@ const file_proxy_wireguard_config_proto_rawDesc = "" +
|
||||
"\n" +
|
||||
"keep_alive\x18\x04 \x01(\tR\tkeepAlive\x12\x1f\n" +
|
||||
"\vallowed_ips\x18\x05 \x03(\tR\n" +
|
||||
"allowedIps\"\xaa\x03\n" +
|
||||
"allowedIps\"\xdc\x03\n" +
|
||||
"\fDeviceConfig\x12\x1d\n" +
|
||||
"\n" +
|
||||
"secret_key\x18\x01 \x01(\tR\tsecretKey\x12\x1a\n" +
|
||||
"\bendpoint\x18\x02 \x03(\tR\bendpoint\x126\n" +
|
||||
"\x05peers\x18\x03 \x03(\v2 .xray.proxy.wireguard.PeerConfigR\x05peers\x12\x10\n" +
|
||||
"\x05peers\x18\x03 \x03(\v2 .xray.proxy.wireguard.PeerConfigR\x05peers\x120\n" +
|
||||
"\x05users\x18\x05 \x03(\v2\x1a.xray.common.protocol.UserR\x05users\x12\x10\n" +
|
||||
"\x03mtu\x18\x04 \x01(\x05R\x03mtu\x12\x1a\n" +
|
||||
"\breserved\x18\x06 \x01(\fR\breserved\x12Z\n" +
|
||||
"\x0fdomain_strategy\x18\a \x01(\x0e21.xray.proxy.wireguard.DeviceConfig.DomainStrategyR\x0edomainStrategy\x12\x1b\n" +
|
||||
@@ -305,15 +315,17 @@ var file_proxy_wireguard_config_proto_goTypes = []any{
|
||||
(DeviceConfig_DomainStrategy)(0), // 0: xray.proxy.wireguard.DeviceConfig.DomainStrategy
|
||||
(*PeerConfig)(nil), // 1: xray.proxy.wireguard.PeerConfig
|
||||
(*DeviceConfig)(nil), // 2: xray.proxy.wireguard.DeviceConfig
|
||||
(*protocol.User)(nil), // 3: xray.common.protocol.User
|
||||
}
|
||||
var file_proxy_wireguard_config_proto_depIdxs = []int32{
|
||||
1, // 0: xray.proxy.wireguard.DeviceConfig.peers:type_name -> xray.proxy.wireguard.PeerConfig
|
||||
0, // 1: xray.proxy.wireguard.DeviceConfig.domain_strategy:type_name -> xray.proxy.wireguard.DeviceConfig.DomainStrategy
|
||||
2, // [2:2] is the sub-list for method output_type
|
||||
2, // [2:2] is the sub-list for method input_type
|
||||
2, // [2:2] is the sub-list for extension type_name
|
||||
2, // [2:2] is the sub-list for extension extendee
|
||||
0, // [0:2] is the sub-list for field type_name
|
||||
3, // 1: xray.proxy.wireguard.DeviceConfig.users:type_name -> xray.common.protocol.User
|
||||
0, // 2: xray.proxy.wireguard.DeviceConfig.domain_strategy:type_name -> xray.proxy.wireguard.DeviceConfig.DomainStrategy
|
||||
3, // [3:3] is the sub-list for method output_type
|
||||
3, // [3:3] is the sub-list for method input_type
|
||||
3, // [3:3] is the sub-list for extension type_name
|
||||
3, // [3:3] is the sub-list for extension extendee
|
||||
0, // [0:3] is the sub-list for field type_name
|
||||
}
|
||||
|
||||
func init() { file_proxy_wireguard_config_proto_init() }
|
||||
|
||||
@@ -6,6 +6,8 @@ option go_package = "github.com/xtls/xray-core/proxy/wireguard";
|
||||
option java_package = "com.xray.proxy.wireguard";
|
||||
option java_multiple_files = true;
|
||||
|
||||
import "common/protocol/user.proto";
|
||||
|
||||
message PeerConfig {
|
||||
string public_key = 1;
|
||||
string pre_shared_key = 2;
|
||||
@@ -25,6 +27,7 @@ message DeviceConfig {
|
||||
string secret_key = 1;
|
||||
repeated string endpoint = 2;
|
||||
repeated PeerConfig peers = 3;
|
||||
repeated xray.common.protocol.User users = 5;
|
||||
int32 mtu = 4;
|
||||
|
||||
bytes reserved = 6;
|
||||
|
||||
+156
-16
@@ -2,16 +2,20 @@ package wireguard
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/xtls/xray-core/common"
|
||||
"github.com/xtls/xray-core/common/buf"
|
||||
c "github.com/xtls/xray-core/common/ctx"
|
||||
"github.com/xtls/xray-core/common/errors"
|
||||
"github.com/xtls/xray-core/common/log"
|
||||
"github.com/xtls/xray-core/common/net"
|
||||
"github.com/xtls/xray-core/common/protocol"
|
||||
"github.com/xtls/xray-core/common/session"
|
||||
"github.com/xtls/xray-core/core"
|
||||
"github.com/xtls/xray-core/features/policy"
|
||||
@@ -20,6 +24,7 @@ import (
|
||||
"github.com/xtls/xray-core/transport"
|
||||
"github.com/xtls/xray-core/transport/internet"
|
||||
"github.com/xtls/xray-core/transport/internet/stat"
|
||||
"golang.org/x/crypto/curve25519"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
@@ -42,6 +47,9 @@ type Server struct {
|
||||
stack *stack.Stack
|
||||
dev *device.Device
|
||||
mu sync.Mutex
|
||||
|
||||
pub [32]byte
|
||||
users *sync.Map
|
||||
}
|
||||
|
||||
func NewServer(ctx context.Context, conf *DeviceConfig) (*Server, error) {
|
||||
@@ -72,15 +80,6 @@ func NewServer(ctx context.Context, conf *DeviceConfig) (*Server, error) {
|
||||
}
|
||||
}
|
||||
|
||||
if len(conf.Peers) == 0 {
|
||||
return nil, errors.New("empty peers")
|
||||
}
|
||||
for _, peer := range conf.Peers {
|
||||
if peer.PublicKey == "" {
|
||||
return nil, errors.New("peer without publickey")
|
||||
}
|
||||
}
|
||||
|
||||
localAddresses := make([]netip.Addr, 0, len(conf.Endpoint))
|
||||
for _, localaddress := range conf.Endpoint {
|
||||
addr, err := netip.ParseAddr(localaddress)
|
||||
@@ -101,6 +100,19 @@ func NewServer(ctx context.Context, conf *DeviceConfig) (*Server, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pri := common.Must2(ParseKey(conf.SecretKey))
|
||||
var pub [32]byte
|
||||
curve25519.ScalarBaseMult(&pub, pri)
|
||||
|
||||
users := &sync.Map{}
|
||||
for _, u := range conf.Users {
|
||||
user, err := u.ToMemoryUser()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
users.Store(user.Account.(*MemoryAccount).Pub, user)
|
||||
}
|
||||
|
||||
return &Server{
|
||||
conf: conf,
|
||||
ctx: core.ToBackgroundDetachedContext(ctx),
|
||||
@@ -116,9 +128,100 @@ func NewServer(ctx context.Context, conf *DeviceConfig) (*Server, error) {
|
||||
|
||||
tun: tun,
|
||||
stack: stack,
|
||||
|
||||
pub: pub,
|
||||
users: users,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Server) AddUser(ctx context.Context, user *protocol.MemoryUser) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.dev == nil {
|
||||
return errors.New("too early")
|
||||
}
|
||||
peer := user.Account.(*MemoryAccount)
|
||||
if peer.Pub == s.pub {
|
||||
return errors.New("invalid public key")
|
||||
}
|
||||
var sb strings.Builder
|
||||
sb.WriteString("public_key=" + hex.EncodeToString(peer.Pub[:]) + "\n")
|
||||
sb.WriteString("replace_allowed_ips=true\n")
|
||||
for i := range peer.AllowedIPs {
|
||||
sb.WriteString("allowed_ip=" + peer.AllowedIPs[i].String() + "\n")
|
||||
}
|
||||
if peer.PreSharedKey != "" {
|
||||
sb.WriteString("preshared_key=" + peer.PreSharedKey + "\n")
|
||||
}
|
||||
if peer.KeepAlive != "" {
|
||||
sb.WriteString("persistent_keepalive_interval=" + peer.KeepAlive + "\n")
|
||||
}
|
||||
err := s.dev.IpcSet(sb.String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.users.Store(peer.Pub, user)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) RemoveUser(ctx context.Context, email string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.dev == nil {
|
||||
return errors.New("too early")
|
||||
}
|
||||
if user := s.GetUser(ctx, email); user != nil {
|
||||
peer := user.Account.(*MemoryAccount)
|
||||
err := s.dev.IpcSet("public_key=" + hex.EncodeToString(peer.Pub[:]) + "\nremove=true\n")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.users.Delete(peer.Pub)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) GetUser(ctx context.Context, email string) (user *protocol.MemoryUser) {
|
||||
s.users.Range(func(key, value any) bool {
|
||||
if value.(*protocol.MemoryUser).Email == email {
|
||||
user = value.(*protocol.MemoryUser)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Server) GetUserByAddr(ctx context.Context, addr netip.Addr) (user *protocol.MemoryUser) {
|
||||
s.users.Range(func(key, value any) bool {
|
||||
peer := value.(*protocol.MemoryUser).Account.(*MemoryAccount)
|
||||
for i := range peer.AllowedIPs {
|
||||
if peer.AllowedIPs[i].Contains(addr) {
|
||||
user = value.(*protocol.MemoryUser)
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Server) GetUsers(ctx context.Context) (users []*protocol.MemoryUser) {
|
||||
s.users.Range(func(key, value interface{}) bool {
|
||||
users = append(users, value.(*protocol.MemoryUser))
|
||||
return true
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Server) GetUsersCount(context.Context) (count int64) {
|
||||
s.users.Range(func(key, value interface{}) bool {
|
||||
count++
|
||||
return true
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Network implements proxy.Inbound.Network.
|
||||
func (*Server) Network() []net.Network {
|
||||
return []net.Network{}
|
||||
@@ -196,18 +299,20 @@ func (s *Server) Start() error {
|
||||
dev := device.NewDevice(s.tun, bind, logger)
|
||||
var cfg strings.Builder
|
||||
cfg.WriteString("private_key=" + s.conf.SecretKey + "\n")
|
||||
for _, peer := range s.conf.Peers {
|
||||
cfg.WriteString("public_key=" + peer.PublicKey + "\n")
|
||||
s.users.Range(func(key, value any) bool {
|
||||
peer := value.(*protocol.MemoryUser).Account.(*MemoryAccount)
|
||||
cfg.WriteString("public_key=" + hex.EncodeToString(peer.Pub[:]) + "\n")
|
||||
for i := range peer.AllowedIPs {
|
||||
cfg.WriteString("allowed_ip=" + peer.AllowedIPs[i].String() + "\n")
|
||||
}
|
||||
if peer.PreSharedKey != "" {
|
||||
cfg.WriteString("preshared_key=" + peer.PreSharedKey + "\n")
|
||||
}
|
||||
for _, ip := range peer.AllowedIps {
|
||||
cfg.WriteString("allowed_ip=" + ip + "\n")
|
||||
}
|
||||
if peer.KeepAlive != "" {
|
||||
cfg.WriteString("persistent_keepalive_interval=" + peer.KeepAlive + "\n")
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
err := dev.IpcSet(cfg.String())
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -227,12 +332,36 @@ func (s *Server) HandleConnection(conn net.Conn, dest net.Destination) {
|
||||
defer cancel()
|
||||
ctx = c.ContextWithID(ctx, session.NewID())
|
||||
|
||||
source := net.DestinationFromAddr(conn.RemoteAddr())
|
||||
remote := conn.RemoteAddr()
|
||||
if remote == nil {
|
||||
errors.LogError(context.Background(), "nil remote")
|
||||
return
|
||||
}
|
||||
|
||||
var addr netip.Addr
|
||||
switch v := remote.(type) {
|
||||
case *net.TCPAddr:
|
||||
addr, _ = netip.AddrFromSlice(v.IP)
|
||||
case *net.UDPAddr:
|
||||
addr, _ = netip.AddrFromSlice(v.IP)
|
||||
default:
|
||||
errors.LogError(context.Background(), "invalid addr type ", reflect.TypeOf(v))
|
||||
return
|
||||
}
|
||||
|
||||
user := s.GetUserByAddr(context.TODO(), addr)
|
||||
if user == nil {
|
||||
errors.LogError(context.Background(), "nil user form ", remote, " to ", dest)
|
||||
return
|
||||
}
|
||||
|
||||
source := net.DestinationFromAddr(remote)
|
||||
inbound := session.Inbound{
|
||||
Name: "wireguard",
|
||||
Tag: s.tag,
|
||||
CanSpliceCopy: 3,
|
||||
Source: source,
|
||||
User: user,
|
||||
}
|
||||
|
||||
ctx = session.ContextWithInbound(ctx, &inbound)
|
||||
@@ -257,3 +386,14 @@ func (s *Server) HandleConnection(conn net.Conn, dest net.Destination) {
|
||||
errors.LogError(ctx, errors.New("connection closed").Base(err))
|
||||
}
|
||||
}
|
||||
|
||||
func ParseKey(str string) (*[32]byte, error) {
|
||||
slice, err := hex.DecodeString(str)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(slice) != 32 {
|
||||
return nil, errors.New("len(slice) != 32")
|
||||
}
|
||||
return (*[32]byte)(slice), nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user