mirror of
https://github.com/XTLS/Xray-core.git
synced 2026-07-03 10:18:42 +00:00
Refactor tls cert hot reload
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"maps"
|
||||
"runtime"
|
||||
"sync"
|
||||
"weak"
|
||||
@@ -43,3 +44,16 @@ func (c *WeakCacheMap[K, V]) Store(key K, value *V) {
|
||||
}
|
||||
}, struct{}{})
|
||||
}
|
||||
|
||||
func (c *WeakCacheMap[K, V]) Range(f func(K, *V) bool) {
|
||||
c.mu.Lock()
|
||||
snapshot := maps.Clone(c.m)
|
||||
c.mu.Unlock()
|
||||
for k, v := range snapshot {
|
||||
if value := v.Value(); value != nil {
|
||||
if !f(k, value) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,12 +11,12 @@ import (
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/xtls/xray-core/common/errors"
|
||||
"github.com/xtls/xray-core/common/net"
|
||||
"github.com/xtls/xray-core/common/ocsp"
|
||||
"github.com/xtls/xray-core/common/platform/filesystem"
|
||||
"github.com/xtls/xray-core/common/protocol/tls/cert"
|
||||
"github.com/xtls/xray-core/transport/internet"
|
||||
)
|
||||
@@ -45,91 +45,6 @@ func (c *Config) loadSelfCertPool() (*x509.CertPool, error) {
|
||||
return root, nil
|
||||
}
|
||||
|
||||
// BuildCertificates builds a list of TLS certificates from proto definition.
|
||||
func (c *Config) BuildCertificates() []*tls.Certificate {
|
||||
certs := make([]*tls.Certificate, 0, len(c.Certificate))
|
||||
for _, entry := range c.Certificate {
|
||||
if entry.Usage != Certificate_ENCIPHERMENT {
|
||||
continue
|
||||
}
|
||||
getX509KeyPair := func() *tls.Certificate {
|
||||
keyPair, err := tls.X509KeyPair(entry.Certificate, entry.Key)
|
||||
if err != nil {
|
||||
errors.LogWarningInner(context.Background(), err, "ignoring invalid X509 key pair")
|
||||
return nil
|
||||
}
|
||||
keyPair.Leaf, err = x509.ParseCertificate(keyPair.Certificate[0])
|
||||
if err != nil {
|
||||
errors.LogWarningInner(context.Background(), err, "ignoring invalid certificate")
|
||||
return nil
|
||||
}
|
||||
return &keyPair
|
||||
}
|
||||
if keyPair := getX509KeyPair(); keyPair != nil {
|
||||
certs = append(certs, keyPair)
|
||||
} else {
|
||||
continue
|
||||
}
|
||||
index := len(certs) - 1
|
||||
setupOcspTicker(entry, func(isReloaded, isOcspstapling bool) {
|
||||
cert := certs[index]
|
||||
if isReloaded {
|
||||
if newKeyPair := getX509KeyPair(); newKeyPair != nil {
|
||||
cert = newKeyPair
|
||||
} else {
|
||||
return
|
||||
}
|
||||
}
|
||||
if isOcspstapling {
|
||||
if newOCSPData, err := ocsp.GetOCSPForCert(cert.Certificate); err != nil {
|
||||
errors.LogWarningInner(context.Background(), err, "ignoring invalid OCSP")
|
||||
} else if string(newOCSPData) != string(cert.OCSPStaple) {
|
||||
cert.OCSPStaple = newOCSPData
|
||||
}
|
||||
}
|
||||
certs[index] = cert
|
||||
})
|
||||
}
|
||||
return certs
|
||||
}
|
||||
|
||||
func setupOcspTicker(entry *Certificate, callback func(isReloaded, isOcspstapling bool)) {
|
||||
go func() {
|
||||
if entry.OneTimeLoading {
|
||||
return
|
||||
}
|
||||
var isOcspstapling bool
|
||||
hotReloadCertInterval := uint64(3600)
|
||||
if entry.OcspStapling != 0 {
|
||||
hotReloadCertInterval = entry.OcspStapling
|
||||
isOcspstapling = true
|
||||
}
|
||||
t := time.NewTicker(time.Duration(hotReloadCertInterval) * time.Second)
|
||||
for {
|
||||
var isReloaded bool
|
||||
if entry.CertificatePath != "" && entry.KeyPath != "" {
|
||||
newCert, err := filesystem.ReadCert(entry.CertificatePath)
|
||||
if err != nil {
|
||||
errors.LogErrorInner(context.Background(), err, "failed to parse certificate")
|
||||
return
|
||||
}
|
||||
newKey, err := filesystem.ReadCert(entry.KeyPath)
|
||||
if err != nil {
|
||||
errors.LogErrorInner(context.Background(), err, "failed to parse key")
|
||||
return
|
||||
}
|
||||
if string(newCert) != string(entry.Certificate) || string(newKey) != string(entry.Key) {
|
||||
entry.Certificate = newCert
|
||||
entry.Key = newKey
|
||||
isReloaded = true
|
||||
}
|
||||
}
|
||||
callback(isReloaded, isOcspstapling)
|
||||
<-t.C
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func isCertificateExpired(c *tls.Certificate) bool {
|
||||
if c.Leaf == nil && len(c.Certificate) > 0 {
|
||||
if pc, err := x509.ParseCertificate(c.Certificate[0]); err == nil {
|
||||
@@ -163,7 +78,7 @@ func (c *Config) getCustomCA() []*Certificate {
|
||||
for _, certificate := range c.Certificate {
|
||||
if certificate.Usage == Certificate_AUTHORITY_ISSUE {
|
||||
certs = append(certs, certificate)
|
||||
setupOcspTicker(certificate, func(isReloaded, isOcspstapling bool) {})
|
||||
setupHotReload(certificate)
|
||||
}
|
||||
}
|
||||
return certs
|
||||
@@ -243,34 +158,77 @@ func getGetCertificateFunc(c *tls.Config, ca []*Certificate) func(hello *tls.Cli
|
||||
}
|
||||
}
|
||||
|
||||
func getNewGetCertificateFunc(certs []*tls.Certificate, rejectUnknownSNI bool) func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
return func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
if len(certs) == 0 {
|
||||
return nil, errNoCertificates
|
||||
// atomic.Pointer must be exactly one pointer word
|
||||
var _ [unsafe.Sizeof(unsafe.Pointer(nil))]byte = [unsafe.Sizeof(atomic.Pointer[tls.Certificate]{})]byte{}
|
||||
|
||||
func (c *Certificate) parseX509KeyPair() *tls.Certificate {
|
||||
keyPair, err := tls.X509KeyPair(c.Certificate, c.Key)
|
||||
if err != nil {
|
||||
errors.LogWarningInner(context.Background(), err, "ignoring invalid X509 key pair")
|
||||
return nil
|
||||
}
|
||||
keyPair.Leaf, err = x509.ParseCertificate(keyPair.Certificate[0])
|
||||
if err != nil {
|
||||
errors.LogWarningInner(context.Background(), err, "ignoring invalid certificate")
|
||||
return nil
|
||||
}
|
||||
if len(c.OcspData) > 0 {
|
||||
keyPair.OCSPStaple = c.OcspData
|
||||
}
|
||||
// wtf is this
|
||||
(*atomic.Pointer[tls.Certificate])(unsafe.Pointer(&c.ParsedCache)).Store(&keyPair)
|
||||
return &keyPair
|
||||
}
|
||||
|
||||
func (c *Certificate) getX509KeyPair() *tls.Certificate {
|
||||
if keyPair := (*atomic.Pointer[tls.Certificate])(unsafe.Pointer(&c.ParsedCache)).Load(); keyPair != nil {
|
||||
return keyPair
|
||||
}
|
||||
return c.parseX509KeyPair()
|
||||
}
|
||||
|
||||
func (c *Config) getCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
var defaultCert *tls.Certificate
|
||||
for _, cert := range c.Certificate {
|
||||
if cert.Usage == Certificate_ENCIPHERMENT {
|
||||
defaultCert = cert.getX509KeyPair()
|
||||
if defaultCert != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
sni := strings.ToLower(hello.ServerName)
|
||||
if !rejectUnknownSNI && (len(certs) == 1 || sni == "") {
|
||||
return certs[0], nil
|
||||
}
|
||||
if defaultCert == nil {
|
||||
return nil, errNoCertificates
|
||||
}
|
||||
sni := strings.ToLower(hello.ServerName)
|
||||
if !c.RejectUnknownSni && (len(c.Certificate) == 1 || sni == "") {
|
||||
return defaultCert, nil
|
||||
}
|
||||
gsni := "*"
|
||||
if index := strings.IndexByte(sni, '.'); index != -1 {
|
||||
gsni += sni[index:]
|
||||
}
|
||||
for _, rawCertificate := range c.Certificate {
|
||||
if rawCertificate.Usage != Certificate_ENCIPHERMENT {
|
||||
continue
|
||||
}
|
||||
gsni := "*"
|
||||
if index := strings.IndexByte(sni, '.'); index != -1 {
|
||||
gsni += sni[index:]
|
||||
keyPair := rawCertificate.getX509KeyPair()
|
||||
if keyPair == nil {
|
||||
continue
|
||||
}
|
||||
for _, keyPair := range certs {
|
||||
if keyPair.Leaf.Subject.CommonName == sni || keyPair.Leaf.Subject.CommonName == gsni {
|
||||
if keyPair.Leaf.Subject.CommonName == sni || keyPair.Leaf.Subject.CommonName == gsni {
|
||||
return keyPair, nil
|
||||
}
|
||||
for _, name := range keyPair.Leaf.DNSNames {
|
||||
if name == sni || name == gsni {
|
||||
return keyPair, nil
|
||||
}
|
||||
for _, name := range keyPair.Leaf.DNSNames {
|
||||
if name == sni || name == gsni {
|
||||
return keyPair, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
if rejectUnknownSNI {
|
||||
return nil, errNoCertificates
|
||||
}
|
||||
return certs[0], nil
|
||||
}
|
||||
if c.RejectUnknownSni {
|
||||
return nil, errNoCertificates
|
||||
}
|
||||
return defaultCert, nil
|
||||
}
|
||||
|
||||
func (c *Config) parseServerName() string {
|
||||
@@ -408,7 +366,12 @@ func (c *Config) GetTLSConfig(opts ...Option) *tls.Config {
|
||||
if len(caCerts) > 0 {
|
||||
config.GetCertificate = getGetCertificateFunc(config, caCerts)
|
||||
} else {
|
||||
config.GetCertificate = getNewGetCertificateFunc(c.BuildCertificates(), c.RejectUnknownSni)
|
||||
for _, cert := range c.Certificate {
|
||||
if cert.Usage == Certificate_ENCIPHERMENT {
|
||||
setupHotReload(cert)
|
||||
}
|
||||
}
|
||||
config.GetCertificate = c.getCertificate
|
||||
}
|
||||
|
||||
if sn := c.parseServerName(); len(sn) > 0 {
|
||||
|
||||
@@ -86,8 +86,12 @@ type Certificate struct {
|
||||
// If true, one-Time Loading
|
||||
OneTimeLoading bool `protobuf:"varint,7,opt,name=One_time_loading,json=OneTimeLoading,proto3" json:"One_time_loading,omitempty"`
|
||||
BuildChain bool `protobuf:"varint,8,opt,name=build_chain,json=buildChain,proto3" json:"build_chain,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
// Abuse proto data for storage hotreload data
|
||||
ParsedCache []byte `protobuf:"bytes,9,opt,name=parsed_cache,json=parsedCache,proto3" json:"parsed_cache,omitempty"`
|
||||
OcspData []byte `protobuf:"bytes,10,opt,name=ocsp_data,json=ocspData,proto3" json:"ocsp_data,omitempty"`
|
||||
LastReload int64 `protobuf:"varint,11,opt,name=last_reload,json=lastReload,proto3" json:"last_reload,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *Certificate) Reset() {
|
||||
@@ -176,6 +180,27 @@ func (x *Certificate) GetBuildChain() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (x *Certificate) GetParsedCache() []byte {
|
||||
if x != nil {
|
||||
return x.ParsedCache
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *Certificate) GetOcspData() []byte {
|
||||
if x != nil {
|
||||
return x.OcspData
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *Certificate) GetLastReload() int64 {
|
||||
if x != nil {
|
||||
return x.LastReload
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
// List of certificates to be served on server.
|
||||
@@ -363,7 +388,7 @@ var File_transport_internet_tls_config_proto protoreflect.FileDescriptor
|
||||
|
||||
const file_transport_internet_tls_config_proto_rawDesc = "" +
|
||||
"\n" +
|
||||
"#transport/internet/tls/config.proto\x12\x1bxray.transport.internet.tls\x1a\x1ftransport/internet/config.proto\"\x83\x03\n" +
|
||||
"#transport/internet/tls/config.proto\x12\x1bxray.transport.internet.tls\x1a\x1ftransport/internet/config.proto\"\xe4\x03\n" +
|
||||
"\vCertificate\x12 \n" +
|
||||
"\vcertificate\x18\x01 \x01(\fR\vcertificate\x12\x10\n" +
|
||||
"\x03key\x18\x02 \x01(\fR\x03key\x12D\n" +
|
||||
@@ -373,7 +398,12 @@ const file_transport_internet_tls_config_proto_rawDesc = "" +
|
||||
"\bkey_path\x18\x06 \x01(\tR\akeyPath\x12(\n" +
|
||||
"\x10One_time_loading\x18\a \x01(\bR\x0eOneTimeLoading\x12\x1f\n" +
|
||||
"\vbuild_chain\x18\b \x01(\bR\n" +
|
||||
"buildChain\"D\n" +
|
||||
"buildChain\x12!\n" +
|
||||
"\fparsed_cache\x18\t \x01(\fR\vparsedCache\x12\x1b\n" +
|
||||
"\tocsp_data\x18\n" +
|
||||
" \x01(\fR\bocspData\x12\x1f\n" +
|
||||
"\vlast_reload\x18\v \x01(\x03R\n" +
|
||||
"lastReload\"D\n" +
|
||||
"\x05Usage\x12\x10\n" +
|
||||
"\fENCIPHERMENT\x10\x00\x12\x14\n" +
|
||||
"\x10AUTHORITY_VERIFY\x10\x01\x12\x13\n" +
|
||||
|
||||
@@ -35,6 +35,11 @@ message Certificate {
|
||||
bool One_time_loading = 7;
|
||||
|
||||
bool build_chain = 8;
|
||||
|
||||
// Abuse proto data for storage hotreload data
|
||||
bytes parsed_cache = 9;
|
||||
bytes ocsp_data = 10;
|
||||
int64 last_reload = 11;
|
||||
}
|
||||
|
||||
message Config {
|
||||
|
||||
@@ -157,7 +157,7 @@ func QueryRecord(domain string, server string, sockopt *internet.SocketConfig) (
|
||||
// If expire is zero value, it means we are in initial state, wait for the query to finish
|
||||
// otherwise return old value immediately and update in a goroutine
|
||||
// but if the cache is too old, wait for update
|
||||
if configRecord.expire == (time.Time{}) || configRecord.expire.Add(time.Hour*4).Before(time.Now()) {
|
||||
if configRecord.expire.Equal(time.Time{}) || configRecord.expire.Add(time.Hour*4).Before(time.Now()) {
|
||||
return echConfigCache.Update(domain, server, false, sockopt)
|
||||
} else {
|
||||
// If someone already acquired the lock, it means it is updating, do not start another update goroutine
|
||||
|
||||
@@ -0,0 +1,84 @@
|
||||
package tls
|
||||
|
||||
import (
|
||||
"context"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/xtls/xray-core/common/errors"
|
||||
"github.com/xtls/xray-core/common/ocsp"
|
||||
"github.com/xtls/xray-core/common/platform/filesystem"
|
||||
"github.com/xtls/xray-core/common/utils"
|
||||
)
|
||||
|
||||
var certsCache = utils.NewWeakCacheMap[uintptr, Certificate]()
|
||||
|
||||
var startHotReload sync.Once
|
||||
|
||||
func setupHotReload(entry *Certificate) {
|
||||
startHotReload.Do(func() {
|
||||
go handleHotReload()
|
||||
})
|
||||
// ensure the cache before use
|
||||
entry.getX509KeyPair()
|
||||
if entry.OneTimeLoading {
|
||||
return
|
||||
}
|
||||
uptr := uintptr(unsafe.Pointer(entry))
|
||||
if _, ok := certsCache.Load(uptr); !ok {
|
||||
certsCache.Store(uptr, entry)
|
||||
}
|
||||
}
|
||||
|
||||
func handleHotReload() {
|
||||
// should be enough?
|
||||
t := time.NewTicker(600 * time.Second)
|
||||
for {
|
||||
certsCache.Range(updateCert)
|
||||
<-t.C
|
||||
}
|
||||
}
|
||||
|
||||
func updateCert(_ uintptr, entry *Certificate) bool {
|
||||
reloadInterval := int64(entry.OcspStapling)
|
||||
if reloadInterval <= 0 {
|
||||
reloadInterval = 3600
|
||||
}
|
||||
if entry.LastReload+reloadInterval >= time.Now().Unix() {
|
||||
return true
|
||||
} else {
|
||||
entry.LastReload = time.Now().Unix()
|
||||
}
|
||||
if entry.CertificatePath != "" && entry.KeyPath != "" {
|
||||
newCert, err := filesystem.ReadCert(entry.CertificatePath)
|
||||
if err != nil {
|
||||
errors.LogErrorInner(context.Background(), err, "failed to parse certificate")
|
||||
return true
|
||||
}
|
||||
newKey, err := filesystem.ReadCert(entry.KeyPath)
|
||||
if err != nil {
|
||||
errors.LogErrorInner(context.Background(), err, "failed to parse key")
|
||||
return true
|
||||
}
|
||||
if string(newCert) != string(entry.Certificate) || string(newKey) != string(entry.Key) {
|
||||
entry.Certificate = newCert
|
||||
entry.Key = newKey
|
||||
}
|
||||
}
|
||||
entry.parseX509KeyPair()
|
||||
if entry.OcspStapling > 0 {
|
||||
keyPair := entry.getX509KeyPair()
|
||||
if keyPair == nil {
|
||||
return true
|
||||
}
|
||||
if newOCSPData, err := ocsp.GetOCSPForCert(keyPair.Certificate); err != nil {
|
||||
errors.LogWarningInner(context.Background(), err, "ignoring invalid OCSP")
|
||||
} else if !slices.Equal(newOCSPData, entry.OcspData) {
|
||||
entry.OcspData = newOCSPData
|
||||
}
|
||||
entry.parseX509KeyPair()
|
||||
}
|
||||
return true
|
||||
}
|
||||
Reference in New Issue
Block a user