From 676212c789d614070db8ba4e38dfe1a4917ecaf2 Mon Sep 17 00:00:00 2001 From: Fangliding Date: Sun, 21 Jun 2026 19:30:57 +0800 Subject: [PATCH] Refine proto usage --- transport/internet/tls/config.go | 33 ++++++++++++++++++------ transport/internet/tls/config.pb.go | 35 ++++++-------------------- transport/internet/tls/config.proto | 6 ++--- transport/internet/tls/hot_reloader.go | 9 ++++--- 4 files changed, 41 insertions(+), 42 deletions(-) diff --git a/transport/internet/tls/config.go b/transport/internet/tls/config.go index 0f5a0246f..b6d0e7bcd 100644 --- a/transport/internet/tls/config.go +++ b/transport/internet/tls/config.go @@ -182,8 +182,27 @@ func getGetCertificateFunc(c *tls.Config, ca []*Certificate) func(hello *tls.Cli } } -// atomic.Pointer must be exactly one pointer word -var _ [unsafe.Sizeof(unsafe.Pointer(nil))]byte = [unsafe.Sizeof(atomic.Pointer[tls.Certificate]{})]byte{} +type extraProtoCertData struct { + parsed atomic.Pointer[tls.Certificate] + ocspData atomic.Pointer[[]byte] + lastReload int64 +} + +// atomic.Pointer must be exactly one pointer word for the ParsedCache overlay below. +var _ [unsafe.Sizeof(unsafe.Pointer(nil))]byte = [unsafe.Sizeof(atomic.Pointer[extraProtoCertData]{})]byte{} + +func (c *Certificate) extraData() *extraProtoCertData { + // wtf is this + slot := (*atomic.Pointer[extraProtoCertData])(unsafe.Pointer(&c.ExtraData)) + if s := slot.Load(); s != nil { + return s + } + s := &extraProtoCertData{} + if slot.CompareAndSwap(nil, s) { + return s + } + return slot.Load() +} func (c *Certificate) parseX509KeyPair() *tls.Certificate { keyPair, err := tls.X509KeyPair(c.Certificate, c.Key) @@ -196,16 +215,16 @@ func (c *Certificate) parseX509KeyPair() *tls.Certificate { errors.LogWarningInner(context.Background(), err, "ignoring invalid certificate") return nil } - if len(c.OcspData) > 0 { - keyPair.OCSPStaple = c.OcspData + st := c.extraData() + if OCSPData := st.ocspData.Load(); OCSPData != nil { + keyPair.OCSPStaple = *OCSPData } - // wtf is this - (*atomic.Pointer[tls.Certificate])(unsafe.Pointer(&c.ParsedCache)).Store(&keyPair) + st.parsed.Store(&keyPair) return &keyPair } func (c *Certificate) getX509KeyPair() *tls.Certificate { - if keyPair := (*atomic.Pointer[tls.Certificate])(unsafe.Pointer(&c.ParsedCache)).Load(); keyPair != nil { + if keyPair := c.extraData().parsed.Load(); keyPair != nil { return keyPair } return c.parseX509KeyPair() diff --git a/transport/internet/tls/config.pb.go b/transport/internet/tls/config.pb.go index 5b727553e..29ad390f3 100644 --- a/transport/internet/tls/config.pb.go +++ b/transport/internet/tls/config.pb.go @@ -92,10 +92,8 @@ type Certificate struct { // If true, one-Time Loading OneTimeLoading bool `protobuf:"varint,7,opt,name=One_time_loading,json=OneTimeLoading,proto3" json:"One_time_loading,omitempty"` BuildChain bool `protobuf:"varint,8,opt,name=build_chain,json=buildChain,proto3" json:"build_chain,omitempty"` - // Abuse proto data for storage hotreload data - ParsedCache []byte `protobuf:"bytes,9,opt,name=parsed_cache,json=parsedCache,proto3" json:"parsed_cache,omitempty"` - OcspData []byte `protobuf:"bytes,10,opt,name=ocsp_data,json=ocspData,proto3" json:"ocsp_data,omitempty"` - LastReload int64 `protobuf:"varint,11,opt,name=last_reload,json=lastReload,proto3" json:"last_reload,omitempty"` + // Abused proto data to storage some runtime data + ExtraData []byte `protobuf:"bytes,9,opt,name=extra_data,json=extraData,proto3" json:"extra_data,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -186,27 +184,13 @@ func (x *Certificate) GetBuildChain() bool { return false } -func (x *Certificate) GetParsedCache() []byte { +func (x *Certificate) GetExtraData() []byte { if x != nil { - return x.ParsedCache + return x.ExtraData } return nil } -func (x *Certificate) GetOcspData() []byte { - if x != nil { - return x.OcspData - } - return nil -} - -func (x *Certificate) GetLastReload() int64 { - if x != nil { - return x.LastReload - } - return 0 -} - type Config struct { state protoimpl.MessageState `protogen:"open.v1"` // List of certificates to be served on server. @@ -402,7 +386,7 @@ var File_transport_internet_tls_config_proto protoreflect.FileDescriptor const file_transport_internet_tls_config_proto_rawDesc = "" + "\n" + - "#transport/internet/tls/config.proto\x12\x1bxray.transport.internet.tls\x1a\x1ftransport/internet/config.proto\"\x8e\x04\n" + + "#transport/internet/tls/config.proto\x12\x1bxray.transport.internet.tls\x1a\x1ftransport/internet/config.proto\"\xcc\x03\n" + "\vCertificate\x12 \n" + "\vcertificate\x18\x01 \x01(\fR\vcertificate\x12\x10\n" + "\x03key\x18\x02 \x01(\fR\x03key\x12D\n" + @@ -412,12 +396,9 @@ const file_transport_internet_tls_config_proto_rawDesc = "" + "\bkey_path\x18\x06 \x01(\tR\akeyPath\x12(\n" + "\x10One_time_loading\x18\a \x01(\bR\x0eOneTimeLoading\x12\x1f\n" + "\vbuild_chain\x18\b \x01(\bR\n" + - "buildChain\x12!\n" + - "\fparsed_cache\x18\t \x01(\fR\vparsedCache\x12\x1b\n" + - "\tocsp_data\x18\n" + - " \x01(\fR\bocspData\x12\x1f\n" + - "\vlast_reload\x18\v \x01(\x03R\n" + - "lastReload\"n\n" + + "buildChain\x12\x1d\n" + + "\n" + + "extra_data\x18\t \x01(\fR\textraData\"n\n" + "\x05Usage\x12\x10\n" + "\fENCIPHERMENT\x10\x00\x12\x14\n" + "\x10AUTHORITY_VERIFY\x10\x01\x12\x13\n" + diff --git a/transport/internet/tls/config.proto b/transport/internet/tls/config.proto index fd5f38807..bc3e5df54 100644 --- a/transport/internet/tls/config.proto +++ b/transport/internet/tls/config.proto @@ -38,10 +38,8 @@ message Certificate { bool build_chain = 8; - // Abuse proto data for storage hotreload data - bytes parsed_cache = 9; - bytes ocsp_data = 10; - int64 last_reload = 11; + // Abused proto data to storage some runtime data + bytes extra_data = 9; } message Config { diff --git a/transport/internet/tls/hot_reloader.go b/transport/internet/tls/hot_reloader.go index 0ca60fb05..a154fe5ef 100644 --- a/transport/internet/tls/hot_reloader.go +++ b/transport/internet/tls/hot_reloader.go @@ -42,14 +42,15 @@ func handleHotReload() { } func updateCert(_ uintptr, entry *Certificate) bool { + extraData := entry.extraData() reloadInterval := int64(entry.OcspStapling) if reloadInterval <= 0 { reloadInterval = 3600 } - if entry.LastReload+reloadInterval >= time.Now().Unix() { + if extraData.lastReload+reloadInterval >= time.Now().Unix() { return true } else { - entry.LastReload = time.Now().Unix() + extraData.lastReload = time.Now().Unix() } if entry.CertificatePath != "" && entry.KeyPath != "" { newCert, err := filesystem.ReadCert(entry.CertificatePath) @@ -75,8 +76,8 @@ func updateCert(_ uintptr, entry *Certificate) bool { } if newOCSPData, err := ocsp.GetOCSPForCert(keyPair.Certificate); err != nil { errors.LogWarningInner(context.Background(), err, "ignoring invalid OCSP") - } else if !slices.Equal(newOCSPData, entry.OcspData) { - entry.OcspData = newOCSPData + } else if OCSPData := extraData.ocspData.Load(); OCSPData == nil || !slices.Equal(newOCSPData, *OCSPData) { + extraData.ocspData.Store(&newOCSPData) } entry.parseX509KeyPair() }