Refactor tls cert hot reload

This commit is contained in:
Fangliding
2026-06-14 03:24:08 +08:00
parent d792fba59c
commit 1b8d07f1e3
6 changed files with 211 additions and 115 deletions
+14
View File
@@ -1,6 +1,7 @@
package utils
import (
"maps"
"runtime"
"sync"
"weak"
@@ -43,3 +44,16 @@ func (c *WeakCacheMap[K, V]) Store(key K, value *V) {
}
}, struct{}{})
}
func (c *WeakCacheMap[K, V]) Range(f func(K, *V) bool) {
c.mu.Lock()
snapshot := maps.Clone(c.m)
c.mu.Unlock()
for k, v := range snapshot {
if value := v.Value(); value != nil {
if !f(k, value) {
break
}
}
}
}
+73 -110
View File
@@ -11,12 +11,12 @@ import (
"slices"
"strings"
"sync"
"sync/atomic"
"time"
"unsafe"
"github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/ocsp"
"github.com/xtls/xray-core/common/platform/filesystem"
"github.com/xtls/xray-core/common/protocol/tls/cert"
"github.com/xtls/xray-core/transport/internet"
)
@@ -45,91 +45,6 @@ func (c *Config) loadSelfCertPool() (*x509.CertPool, error) {
return root, nil
}
// BuildCertificates builds a list of TLS certificates from proto definition.
func (c *Config) BuildCertificates() []*tls.Certificate {
certs := make([]*tls.Certificate, 0, len(c.Certificate))
for _, entry := range c.Certificate {
if entry.Usage != Certificate_ENCIPHERMENT {
continue
}
getX509KeyPair := func() *tls.Certificate {
keyPair, err := tls.X509KeyPair(entry.Certificate, entry.Key)
if err != nil {
errors.LogWarningInner(context.Background(), err, "ignoring invalid X509 key pair")
return nil
}
keyPair.Leaf, err = x509.ParseCertificate(keyPair.Certificate[0])
if err != nil {
errors.LogWarningInner(context.Background(), err, "ignoring invalid certificate")
return nil
}
return &keyPair
}
if keyPair := getX509KeyPair(); keyPair != nil {
certs = append(certs, keyPair)
} else {
continue
}
index := len(certs) - 1
setupOcspTicker(entry, func(isReloaded, isOcspstapling bool) {
cert := certs[index]
if isReloaded {
if newKeyPair := getX509KeyPair(); newKeyPair != nil {
cert = newKeyPair
} else {
return
}
}
if isOcspstapling {
if newOCSPData, err := ocsp.GetOCSPForCert(cert.Certificate); err != nil {
errors.LogWarningInner(context.Background(), err, "ignoring invalid OCSP")
} else if string(newOCSPData) != string(cert.OCSPStaple) {
cert.OCSPStaple = newOCSPData
}
}
certs[index] = cert
})
}
return certs
}
func setupOcspTicker(entry *Certificate, callback func(isReloaded, isOcspstapling bool)) {
go func() {
if entry.OneTimeLoading {
return
}
var isOcspstapling bool
hotReloadCertInterval := uint64(3600)
if entry.OcspStapling != 0 {
hotReloadCertInterval = entry.OcspStapling
isOcspstapling = true
}
t := time.NewTicker(time.Duration(hotReloadCertInterval) * time.Second)
for {
var isReloaded bool
if entry.CertificatePath != "" && entry.KeyPath != "" {
newCert, err := filesystem.ReadCert(entry.CertificatePath)
if err != nil {
errors.LogErrorInner(context.Background(), err, "failed to parse certificate")
return
}
newKey, err := filesystem.ReadCert(entry.KeyPath)
if err != nil {
errors.LogErrorInner(context.Background(), err, "failed to parse key")
return
}
if string(newCert) != string(entry.Certificate) || string(newKey) != string(entry.Key) {
entry.Certificate = newCert
entry.Key = newKey
isReloaded = true
}
}
callback(isReloaded, isOcspstapling)
<-t.C
}
}()
}
func isCertificateExpired(c *tls.Certificate) bool {
if c.Leaf == nil && len(c.Certificate) > 0 {
if pc, err := x509.ParseCertificate(c.Certificate[0]); err == nil {
@@ -163,7 +78,7 @@ func (c *Config) getCustomCA() []*Certificate {
for _, certificate := range c.Certificate {
if certificate.Usage == Certificate_AUTHORITY_ISSUE {
certs = append(certs, certificate)
setupOcspTicker(certificate, func(isReloaded, isOcspstapling bool) {})
setupHotReload(certificate)
}
}
return certs
@@ -243,34 +158,77 @@ func getGetCertificateFunc(c *tls.Config, ca []*Certificate) func(hello *tls.Cli
}
}
func getNewGetCertificateFunc(certs []*tls.Certificate, rejectUnknownSNI bool) func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
return func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
if len(certs) == 0 {
return nil, errNoCertificates
// atomic.Pointer must be exactly one pointer word
var _ [unsafe.Sizeof(unsafe.Pointer(nil))]byte = [unsafe.Sizeof(atomic.Pointer[tls.Certificate]{})]byte{}
func (c *Certificate) parseX509KeyPair() *tls.Certificate {
keyPair, err := tls.X509KeyPair(c.Certificate, c.Key)
if err != nil {
errors.LogWarningInner(context.Background(), err, "ignoring invalid X509 key pair")
return nil
}
keyPair.Leaf, err = x509.ParseCertificate(keyPair.Certificate[0])
if err != nil {
errors.LogWarningInner(context.Background(), err, "ignoring invalid certificate")
return nil
}
if len(c.OcspData) > 0 {
keyPair.OCSPStaple = c.OcspData
}
// wtf is this
(*atomic.Pointer[tls.Certificate])(unsafe.Pointer(&c.ParsedCache)).Store(&keyPair)
return &keyPair
}
func (c *Certificate) getX509KeyPair() *tls.Certificate {
if keyPair := (*atomic.Pointer[tls.Certificate])(unsafe.Pointer(&c.ParsedCache)).Load(); keyPair != nil {
return keyPair
}
return c.parseX509KeyPair()
}
func (c *Config) getCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
var defaultCert *tls.Certificate
for _, cert := range c.Certificate {
if cert.Usage == Certificate_ENCIPHERMENT {
defaultCert = cert.getX509KeyPair()
if defaultCert != nil {
break
}
}
sni := strings.ToLower(hello.ServerName)
if !rejectUnknownSNI && (len(certs) == 1 || sni == "") {
return certs[0], nil
}
if defaultCert == nil {
return nil, errNoCertificates
}
sni := strings.ToLower(hello.ServerName)
if !c.RejectUnknownSni && (len(c.Certificate) == 1 || sni == "") {
return defaultCert, nil
}
gsni := "*"
if index := strings.IndexByte(sni, '.'); index != -1 {
gsni += sni[index:]
}
for _, rawCertificate := range c.Certificate {
if rawCertificate.Usage != Certificate_ENCIPHERMENT {
continue
}
gsni := "*"
if index := strings.IndexByte(sni, '.'); index != -1 {
gsni += sni[index:]
keyPair := rawCertificate.getX509KeyPair()
if keyPair == nil {
continue
}
for _, keyPair := range certs {
if keyPair.Leaf.Subject.CommonName == sni || keyPair.Leaf.Subject.CommonName == gsni {
if keyPair.Leaf.Subject.CommonName == sni || keyPair.Leaf.Subject.CommonName == gsni {
return keyPair, nil
}
for _, name := range keyPair.Leaf.DNSNames {
if name == sni || name == gsni {
return keyPair, nil
}
for _, name := range keyPair.Leaf.DNSNames {
if name == sni || name == gsni {
return keyPair, nil
}
}
}
if rejectUnknownSNI {
return nil, errNoCertificates
}
return certs[0], nil
}
if c.RejectUnknownSni {
return nil, errNoCertificates
}
return defaultCert, nil
}
func (c *Config) parseServerName() string {
@@ -408,7 +366,12 @@ func (c *Config) GetTLSConfig(opts ...Option) *tls.Config {
if len(caCerts) > 0 {
config.GetCertificate = getGetCertificateFunc(config, caCerts)
} else {
config.GetCertificate = getNewGetCertificateFunc(c.BuildCertificates(), c.RejectUnknownSni)
for _, cert := range c.Certificate {
if cert.Usage == Certificate_ENCIPHERMENT {
setupHotReload(cert)
}
}
config.GetCertificate = c.getCertificate
}
if sn := c.parseServerName(); len(sn) > 0 {
+34 -4
View File
@@ -86,8 +86,12 @@ type Certificate struct {
// If true, one-Time Loading
OneTimeLoading bool `protobuf:"varint,7,opt,name=One_time_loading,json=OneTimeLoading,proto3" json:"One_time_loading,omitempty"`
BuildChain bool `protobuf:"varint,8,opt,name=build_chain,json=buildChain,proto3" json:"build_chain,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
// Abuse proto data for storage hotreload data
ParsedCache []byte `protobuf:"bytes,9,opt,name=parsed_cache,json=parsedCache,proto3" json:"parsed_cache,omitempty"`
OcspData []byte `protobuf:"bytes,10,opt,name=ocsp_data,json=ocspData,proto3" json:"ocsp_data,omitempty"`
LastReload int64 `protobuf:"varint,11,opt,name=last_reload,json=lastReload,proto3" json:"last_reload,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *Certificate) Reset() {
@@ -176,6 +180,27 @@ func (x *Certificate) GetBuildChain() bool {
return false
}
func (x *Certificate) GetParsedCache() []byte {
if x != nil {
return x.ParsedCache
}
return nil
}
func (x *Certificate) GetOcspData() []byte {
if x != nil {
return x.OcspData
}
return nil
}
func (x *Certificate) GetLastReload() int64 {
if x != nil {
return x.LastReload
}
return 0
}
type Config struct {
state protoimpl.MessageState `protogen:"open.v1"`
// List of certificates to be served on server.
@@ -363,7 +388,7 @@ var File_transport_internet_tls_config_proto protoreflect.FileDescriptor
const file_transport_internet_tls_config_proto_rawDesc = "" +
"\n" +
"#transport/internet/tls/config.proto\x12\x1bxray.transport.internet.tls\x1a\x1ftransport/internet/config.proto\"\x83\x03\n" +
"#transport/internet/tls/config.proto\x12\x1bxray.transport.internet.tls\x1a\x1ftransport/internet/config.proto\"\xe4\x03\n" +
"\vCertificate\x12 \n" +
"\vcertificate\x18\x01 \x01(\fR\vcertificate\x12\x10\n" +
"\x03key\x18\x02 \x01(\fR\x03key\x12D\n" +
@@ -373,7 +398,12 @@ const file_transport_internet_tls_config_proto_rawDesc = "" +
"\bkey_path\x18\x06 \x01(\tR\akeyPath\x12(\n" +
"\x10One_time_loading\x18\a \x01(\bR\x0eOneTimeLoading\x12\x1f\n" +
"\vbuild_chain\x18\b \x01(\bR\n" +
"buildChain\"D\n" +
"buildChain\x12!\n" +
"\fparsed_cache\x18\t \x01(\fR\vparsedCache\x12\x1b\n" +
"\tocsp_data\x18\n" +
" \x01(\fR\bocspData\x12\x1f\n" +
"\vlast_reload\x18\v \x01(\x03R\n" +
"lastReload\"D\n" +
"\x05Usage\x12\x10\n" +
"\fENCIPHERMENT\x10\x00\x12\x14\n" +
"\x10AUTHORITY_VERIFY\x10\x01\x12\x13\n" +
+5
View File
@@ -35,6 +35,11 @@ message Certificate {
bool One_time_loading = 7;
bool build_chain = 8;
// Abuse proto data for storage hotreload data
bytes parsed_cache = 9;
bytes ocsp_data = 10;
int64 last_reload = 11;
}
message Config {
+1 -1
View File
@@ -157,7 +157,7 @@ func QueryRecord(domain string, server string, sockopt *internet.SocketConfig) (
// If expire is zero value, it means we are in initial state, wait for the query to finish
// otherwise return old value immediately and update in a goroutine
// but if the cache is too old, wait for update
if configRecord.expire == (time.Time{}) || configRecord.expire.Add(time.Hour*4).Before(time.Now()) {
if configRecord.expire.Equal(time.Time{}) || configRecord.expire.Add(time.Hour*4).Before(time.Now()) {
return echConfigCache.Update(domain, server, false, sockopt)
} else {
// If someone already acquired the lock, it means it is updating, do not start another update goroutine
+84
View File
@@ -0,0 +1,84 @@
package tls
import (
"context"
"slices"
"sync"
"time"
"unsafe"
"github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/common/ocsp"
"github.com/xtls/xray-core/common/platform/filesystem"
"github.com/xtls/xray-core/common/utils"
)
var certsCache = utils.NewWeakCacheMap[uintptr, Certificate]()
var startHotReload sync.Once
func setupHotReload(entry *Certificate) {
startHotReload.Do(func() {
go handleHotReload()
})
// ensure the cache before use
entry.getX509KeyPair()
if entry.OneTimeLoading {
return
}
uptr := uintptr(unsafe.Pointer(entry))
if _, ok := certsCache.Load(uptr); !ok {
certsCache.Store(uptr, entry)
}
}
func handleHotReload() {
// should be enough?
t := time.NewTicker(600 * time.Second)
for {
certsCache.Range(updateCert)
<-t.C
}
}
func updateCert(_ uintptr, entry *Certificate) bool {
reloadInterval := int64(entry.OcspStapling)
if reloadInterval <= 0 {
reloadInterval = 3600
}
if entry.LastReload+reloadInterval >= time.Now().Unix() {
return true
} else {
entry.LastReload = time.Now().Unix()
}
if entry.CertificatePath != "" && entry.KeyPath != "" {
newCert, err := filesystem.ReadCert(entry.CertificatePath)
if err != nil {
errors.LogErrorInner(context.Background(), err, "failed to parse certificate")
return true
}
newKey, err := filesystem.ReadCert(entry.KeyPath)
if err != nil {
errors.LogErrorInner(context.Background(), err, "failed to parse key")
return true
}
if string(newCert) != string(entry.Certificate) || string(newKey) != string(entry.Key) {
entry.Certificate = newCert
entry.Key = newKey
}
}
entry.parseX509KeyPair()
if entry.OcspStapling > 0 {
keyPair := entry.getX509KeyPair()
if keyPair == nil {
return true
}
if newOCSPData, err := ocsp.GetOCSPForCert(keyPair.Certificate); err != nil {
errors.LogWarningInner(context.Background(), err, "ignoring invalid OCSP")
} else if !slices.Equal(newOCSPData, entry.OcspData) {
entry.OcspData = newOCSPData
}
entry.parseX509KeyPair()
}
return true
}