diff --git a/common/utils/weak_cache.go b/common/utils/weak_cache.go index a9aaacca6..48cc8ce07 100644 --- a/common/utils/weak_cache.go +++ b/common/utils/weak_cache.go @@ -1,6 +1,7 @@ package utils import ( + "maps" "runtime" "sync" "weak" @@ -43,3 +44,16 @@ func (c *WeakCacheMap[K, V]) Store(key K, value *V) { } }, struct{}{}) } + +func (c *WeakCacheMap[K, V]) Range(f func(K, *V) bool) { + c.mu.Lock() + snapshot := maps.Clone(c.m) + c.mu.Unlock() + for k, v := range snapshot { + if value := v.Value(); value != nil { + if !f(k, value) { + break + } + } + } +} diff --git a/transport/internet/tls/config.go b/transport/internet/tls/config.go index 9c8604699..1a635f993 100644 --- a/transport/internet/tls/config.go +++ b/transport/internet/tls/config.go @@ -11,12 +11,12 @@ import ( "slices" "strings" "sync" + "sync/atomic" "time" + "unsafe" "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/net" - "github.com/xtls/xray-core/common/ocsp" - "github.com/xtls/xray-core/common/platform/filesystem" "github.com/xtls/xray-core/common/protocol/tls/cert" "github.com/xtls/xray-core/transport/internet" ) @@ -45,91 +45,6 @@ func (c *Config) loadSelfCertPool() (*x509.CertPool, error) { return root, nil } -// BuildCertificates builds a list of TLS certificates from proto definition. -func (c *Config) BuildCertificates() []*tls.Certificate { - certs := make([]*tls.Certificate, 0, len(c.Certificate)) - for _, entry := range c.Certificate { - if entry.Usage != Certificate_ENCIPHERMENT { - continue - } - getX509KeyPair := func() *tls.Certificate { - keyPair, err := tls.X509KeyPair(entry.Certificate, entry.Key) - if err != nil { - errors.LogWarningInner(context.Background(), err, "ignoring invalid X509 key pair") - return nil - } - keyPair.Leaf, err = x509.ParseCertificate(keyPair.Certificate[0]) - if err != nil { - errors.LogWarningInner(context.Background(), err, "ignoring invalid certificate") - return nil - } - return &keyPair - } - if keyPair := getX509KeyPair(); keyPair != nil { - certs = append(certs, keyPair) - } else { - continue - } - index := len(certs) - 1 - setupOcspTicker(entry, func(isReloaded, isOcspstapling bool) { - cert := certs[index] - if isReloaded { - if newKeyPair := getX509KeyPair(); newKeyPair != nil { - cert = newKeyPair - } else { - return - } - } - if isOcspstapling { - if newOCSPData, err := ocsp.GetOCSPForCert(cert.Certificate); err != nil { - errors.LogWarningInner(context.Background(), err, "ignoring invalid OCSP") - } else if string(newOCSPData) != string(cert.OCSPStaple) { - cert.OCSPStaple = newOCSPData - } - } - certs[index] = cert - }) - } - return certs -} - -func setupOcspTicker(entry *Certificate, callback func(isReloaded, isOcspstapling bool)) { - go func() { - if entry.OneTimeLoading { - return - } - var isOcspstapling bool - hotReloadCertInterval := uint64(3600) - if entry.OcspStapling != 0 { - hotReloadCertInterval = entry.OcspStapling - isOcspstapling = true - } - t := time.NewTicker(time.Duration(hotReloadCertInterval) * time.Second) - for { - var isReloaded bool - if entry.CertificatePath != "" && entry.KeyPath != "" { - newCert, err := filesystem.ReadCert(entry.CertificatePath) - if err != nil { - errors.LogErrorInner(context.Background(), err, "failed to parse certificate") - return - } - newKey, err := filesystem.ReadCert(entry.KeyPath) - if err != nil { - errors.LogErrorInner(context.Background(), err, "failed to parse key") - return - } - if string(newCert) != string(entry.Certificate) || string(newKey) != string(entry.Key) { - entry.Certificate = newCert - entry.Key = newKey - isReloaded = true - } - } - callback(isReloaded, isOcspstapling) - <-t.C - } - }() -} - func isCertificateExpired(c *tls.Certificate) bool { if c.Leaf == nil && len(c.Certificate) > 0 { if pc, err := x509.ParseCertificate(c.Certificate[0]); err == nil { @@ -163,7 +78,7 @@ func (c *Config) getCustomCA() []*Certificate { for _, certificate := range c.Certificate { if certificate.Usage == Certificate_AUTHORITY_ISSUE { certs = append(certs, certificate) - setupOcspTicker(certificate, func(isReloaded, isOcspstapling bool) {}) + setupHotReload(certificate) } } return certs @@ -243,34 +158,77 @@ func getGetCertificateFunc(c *tls.Config, ca []*Certificate) func(hello *tls.Cli } } -func getNewGetCertificateFunc(certs []*tls.Certificate, rejectUnknownSNI bool) func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { - return func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { - if len(certs) == 0 { - return nil, errNoCertificates +// atomic.Pointer must be exactly one pointer word +var _ [unsafe.Sizeof(unsafe.Pointer(nil))]byte = [unsafe.Sizeof(atomic.Pointer[tls.Certificate]{})]byte{} + +func (c *Certificate) parseX509KeyPair() *tls.Certificate { + keyPair, err := tls.X509KeyPair(c.Certificate, c.Key) + if err != nil { + errors.LogWarningInner(context.Background(), err, "ignoring invalid X509 key pair") + return nil + } + keyPair.Leaf, err = x509.ParseCertificate(keyPair.Certificate[0]) + if err != nil { + errors.LogWarningInner(context.Background(), err, "ignoring invalid certificate") + return nil + } + if len(c.OcspData) > 0 { + keyPair.OCSPStaple = c.OcspData + } + // wtf is this + (*atomic.Pointer[tls.Certificate])(unsafe.Pointer(&c.ParsedCache)).Store(&keyPair) + return &keyPair +} + +func (c *Certificate) getX509KeyPair() *tls.Certificate { + if keyPair := (*atomic.Pointer[tls.Certificate])(unsafe.Pointer(&c.ParsedCache)).Load(); keyPair != nil { + return keyPair + } + return c.parseX509KeyPair() +} + +func (c *Config) getCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { + var defaultCert *tls.Certificate + for _, cert := range c.Certificate { + if cert.Usage == Certificate_ENCIPHERMENT { + defaultCert = cert.getX509KeyPair() + if defaultCert != nil { + break + } } - sni := strings.ToLower(hello.ServerName) - if !rejectUnknownSNI && (len(certs) == 1 || sni == "") { - return certs[0], nil + } + if defaultCert == nil { + return nil, errNoCertificates + } + sni := strings.ToLower(hello.ServerName) + if !c.RejectUnknownSni && (len(c.Certificate) == 1 || sni == "") { + return defaultCert, nil + } + gsni := "*" + if index := strings.IndexByte(sni, '.'); index != -1 { + gsni += sni[index:] + } + for _, rawCertificate := range c.Certificate { + if rawCertificate.Usage != Certificate_ENCIPHERMENT { + continue } - gsni := "*" - if index := strings.IndexByte(sni, '.'); index != -1 { - gsni += sni[index:] + keyPair := rawCertificate.getX509KeyPair() + if keyPair == nil { + continue } - for _, keyPair := range certs { - if keyPair.Leaf.Subject.CommonName == sni || keyPair.Leaf.Subject.CommonName == gsni { + if keyPair.Leaf.Subject.CommonName == sni || keyPair.Leaf.Subject.CommonName == gsni { + return keyPair, nil + } + for _, name := range keyPair.Leaf.DNSNames { + if name == sni || name == gsni { return keyPair, nil } - for _, name := range keyPair.Leaf.DNSNames { - if name == sni || name == gsni { - return keyPair, nil - } - } } - if rejectUnknownSNI { - return nil, errNoCertificates - } - return certs[0], nil } + if c.RejectUnknownSni { + return nil, errNoCertificates + } + return defaultCert, nil } func (c *Config) parseServerName() string { @@ -408,7 +366,12 @@ func (c *Config) GetTLSConfig(opts ...Option) *tls.Config { if len(caCerts) > 0 { config.GetCertificate = getGetCertificateFunc(config, caCerts) } else { - config.GetCertificate = getNewGetCertificateFunc(c.BuildCertificates(), c.RejectUnknownSni) + for _, cert := range c.Certificate { + if cert.Usage == Certificate_ENCIPHERMENT { + setupHotReload(cert) + } + } + config.GetCertificate = c.getCertificate } if sn := c.parseServerName(); len(sn) > 0 { diff --git a/transport/internet/tls/config.pb.go b/transport/internet/tls/config.pb.go index b622d5a61..01822d2cd 100644 --- a/transport/internet/tls/config.pb.go +++ b/transport/internet/tls/config.pb.go @@ -86,8 +86,12 @@ type Certificate struct { // If true, one-Time Loading OneTimeLoading bool `protobuf:"varint,7,opt,name=One_time_loading,json=OneTimeLoading,proto3" json:"One_time_loading,omitempty"` BuildChain bool `protobuf:"varint,8,opt,name=build_chain,json=buildChain,proto3" json:"build_chain,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + // Abuse proto data for storage hotreload data + ParsedCache []byte `protobuf:"bytes,9,opt,name=parsed_cache,json=parsedCache,proto3" json:"parsed_cache,omitempty"` + OcspData []byte `protobuf:"bytes,10,opt,name=ocsp_data,json=ocspData,proto3" json:"ocsp_data,omitempty"` + LastReload int64 `protobuf:"varint,11,opt,name=last_reload,json=lastReload,proto3" json:"last_reload,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *Certificate) Reset() { @@ -176,6 +180,27 @@ func (x *Certificate) GetBuildChain() bool { return false } +func (x *Certificate) GetParsedCache() []byte { + if x != nil { + return x.ParsedCache + } + return nil +} + +func (x *Certificate) GetOcspData() []byte { + if x != nil { + return x.OcspData + } + return nil +} + +func (x *Certificate) GetLastReload() int64 { + if x != nil { + return x.LastReload + } + return 0 +} + type Config struct { state protoimpl.MessageState `protogen:"open.v1"` // List of certificates to be served on server. @@ -363,7 +388,7 @@ var File_transport_internet_tls_config_proto protoreflect.FileDescriptor const file_transport_internet_tls_config_proto_rawDesc = "" + "\n" + - "#transport/internet/tls/config.proto\x12\x1bxray.transport.internet.tls\x1a\x1ftransport/internet/config.proto\"\x83\x03\n" + + "#transport/internet/tls/config.proto\x12\x1bxray.transport.internet.tls\x1a\x1ftransport/internet/config.proto\"\xe4\x03\n" + "\vCertificate\x12 \n" + "\vcertificate\x18\x01 \x01(\fR\vcertificate\x12\x10\n" + "\x03key\x18\x02 \x01(\fR\x03key\x12D\n" + @@ -373,7 +398,12 @@ const file_transport_internet_tls_config_proto_rawDesc = "" + "\bkey_path\x18\x06 \x01(\tR\akeyPath\x12(\n" + "\x10One_time_loading\x18\a \x01(\bR\x0eOneTimeLoading\x12\x1f\n" + "\vbuild_chain\x18\b \x01(\bR\n" + - "buildChain\"D\n" + + "buildChain\x12!\n" + + "\fparsed_cache\x18\t \x01(\fR\vparsedCache\x12\x1b\n" + + "\tocsp_data\x18\n" + + " \x01(\fR\bocspData\x12\x1f\n" + + "\vlast_reload\x18\v \x01(\x03R\n" + + "lastReload\"D\n" + "\x05Usage\x12\x10\n" + "\fENCIPHERMENT\x10\x00\x12\x14\n" + "\x10AUTHORITY_VERIFY\x10\x01\x12\x13\n" + diff --git a/transport/internet/tls/config.proto b/transport/internet/tls/config.proto index a05cc0494..8801b4998 100644 --- a/transport/internet/tls/config.proto +++ b/transport/internet/tls/config.proto @@ -35,6 +35,11 @@ message Certificate { bool One_time_loading = 7; bool build_chain = 8; + + // Abuse proto data for storage hotreload data + bytes parsed_cache = 9; + bytes ocsp_data = 10; + int64 last_reload = 11; } message Config { diff --git a/transport/internet/tls/ech.go b/transport/internet/tls/ech.go index 9de124c1c..942ff0807 100644 --- a/transport/internet/tls/ech.go +++ b/transport/internet/tls/ech.go @@ -157,7 +157,7 @@ func QueryRecord(domain string, server string, sockopt *internet.SocketConfig) ( // If expire is zero value, it means we are in initial state, wait for the query to finish // otherwise return old value immediately and update in a goroutine // but if the cache is too old, wait for update - if configRecord.expire == (time.Time{}) || configRecord.expire.Add(time.Hour*4).Before(time.Now()) { + if configRecord.expire.Equal(time.Time{}) || configRecord.expire.Add(time.Hour*4).Before(time.Now()) { return echConfigCache.Update(domain, server, false, sockopt) } else { // If someone already acquired the lock, it means it is updating, do not start another update goroutine diff --git a/transport/internet/tls/hot_reloader.go b/transport/internet/tls/hot_reloader.go new file mode 100644 index 000000000..0ca60fb05 --- /dev/null +++ b/transport/internet/tls/hot_reloader.go @@ -0,0 +1,84 @@ +package tls + +import ( + "context" + "slices" + "sync" + "time" + "unsafe" + + "github.com/xtls/xray-core/common/errors" + "github.com/xtls/xray-core/common/ocsp" + "github.com/xtls/xray-core/common/platform/filesystem" + "github.com/xtls/xray-core/common/utils" +) + +var certsCache = utils.NewWeakCacheMap[uintptr, Certificate]() + +var startHotReload sync.Once + +func setupHotReload(entry *Certificate) { + startHotReload.Do(func() { + go handleHotReload() + }) + // ensure the cache before use + entry.getX509KeyPair() + if entry.OneTimeLoading { + return + } + uptr := uintptr(unsafe.Pointer(entry)) + if _, ok := certsCache.Load(uptr); !ok { + certsCache.Store(uptr, entry) + } +} + +func handleHotReload() { + // should be enough? + t := time.NewTicker(600 * time.Second) + for { + certsCache.Range(updateCert) + <-t.C + } +} + +func updateCert(_ uintptr, entry *Certificate) bool { + reloadInterval := int64(entry.OcspStapling) + if reloadInterval <= 0 { + reloadInterval = 3600 + } + if entry.LastReload+reloadInterval >= time.Now().Unix() { + return true + } else { + entry.LastReload = time.Now().Unix() + } + if entry.CertificatePath != "" && entry.KeyPath != "" { + newCert, err := filesystem.ReadCert(entry.CertificatePath) + if err != nil { + errors.LogErrorInner(context.Background(), err, "failed to parse certificate") + return true + } + newKey, err := filesystem.ReadCert(entry.KeyPath) + if err != nil { + errors.LogErrorInner(context.Background(), err, "failed to parse key") + return true + } + if string(newCert) != string(entry.Certificate) || string(newKey) != string(entry.Key) { + entry.Certificate = newCert + entry.Key = newKey + } + } + entry.parseX509KeyPair() + if entry.OcspStapling > 0 { + keyPair := entry.getX509KeyPair() + if keyPair == nil { + return true + } + if newOCSPData, err := ocsp.GetOCSPForCert(keyPair.Certificate); err != nil { + errors.LogWarningInner(context.Background(), err, "ignoring invalid OCSP") + } else if !slices.Equal(newOCSPData, entry.OcspData) { + entry.OcspData = newOCSPData + } + entry.parseX509KeyPair() + } + return true +}