Refine proto usage

This commit is contained in:
Fangliding
2026-06-21 19:30:57 +08:00
parent 1ca32a7af8
commit 676212c789
4 changed files with 41 additions and 42 deletions
+26 -7
View File
@@ -182,8 +182,27 @@ func getGetCertificateFunc(c *tls.Config, ca []*Certificate) func(hello *tls.Cli
}
}
// atomic.Pointer must be exactly one pointer word
var _ [unsafe.Sizeof(unsafe.Pointer(nil))]byte = [unsafe.Sizeof(atomic.Pointer[tls.Certificate]{})]byte{}
type extraProtoCertData struct {
parsed atomic.Pointer[tls.Certificate]
ocspData atomic.Pointer[[]byte]
lastReload int64
}
// atomic.Pointer must be exactly one pointer word for the ParsedCache overlay below.
var _ [unsafe.Sizeof(unsafe.Pointer(nil))]byte = [unsafe.Sizeof(atomic.Pointer[extraProtoCertData]{})]byte{}
func (c *Certificate) extraData() *extraProtoCertData {
// wtf is this
slot := (*atomic.Pointer[extraProtoCertData])(unsafe.Pointer(&c.ExtraData))
if s := slot.Load(); s != nil {
return s
}
s := &extraProtoCertData{}
if slot.CompareAndSwap(nil, s) {
return s
}
return slot.Load()
}
func (c *Certificate) parseX509KeyPair() *tls.Certificate {
keyPair, err := tls.X509KeyPair(c.Certificate, c.Key)
@@ -196,16 +215,16 @@ func (c *Certificate) parseX509KeyPair() *tls.Certificate {
errors.LogWarningInner(context.Background(), err, "ignoring invalid certificate")
return nil
}
if len(c.OcspData) > 0 {
keyPair.OCSPStaple = c.OcspData
st := c.extraData()
if OCSPData := st.ocspData.Load(); OCSPData != nil {
keyPair.OCSPStaple = *OCSPData
}
// wtf is this
(*atomic.Pointer[tls.Certificate])(unsafe.Pointer(&c.ParsedCache)).Store(&keyPair)
st.parsed.Store(&keyPair)
return &keyPair
}
func (c *Certificate) getX509KeyPair() *tls.Certificate {
if keyPair := (*atomic.Pointer[tls.Certificate])(unsafe.Pointer(&c.ParsedCache)).Load(); keyPair != nil {
if keyPair := c.extraData().parsed.Load(); keyPair != nil {
return keyPair
}
return c.parseX509KeyPair()
+8 -27
View File
@@ -92,10 +92,8 @@ type Certificate struct {
// If true, one-Time Loading
OneTimeLoading bool `protobuf:"varint,7,opt,name=One_time_loading,json=OneTimeLoading,proto3" json:"One_time_loading,omitempty"`
BuildChain bool `protobuf:"varint,8,opt,name=build_chain,json=buildChain,proto3" json:"build_chain,omitempty"`
// Abuse proto data for storage hotreload data
ParsedCache []byte `protobuf:"bytes,9,opt,name=parsed_cache,json=parsedCache,proto3" json:"parsed_cache,omitempty"`
OcspData []byte `protobuf:"bytes,10,opt,name=ocsp_data,json=ocspData,proto3" json:"ocsp_data,omitempty"`
LastReload int64 `protobuf:"varint,11,opt,name=last_reload,json=lastReload,proto3" json:"last_reload,omitempty"`
// Abused proto data to storage some runtime data
ExtraData []byte `protobuf:"bytes,9,opt,name=extra_data,json=extraData,proto3" json:"extra_data,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -186,27 +184,13 @@ func (x *Certificate) GetBuildChain() bool {
return false
}
func (x *Certificate) GetParsedCache() []byte {
func (x *Certificate) GetExtraData() []byte {
if x != nil {
return x.ParsedCache
return x.ExtraData
}
return nil
}
func (x *Certificate) GetOcspData() []byte {
if x != nil {
return x.OcspData
}
return nil
}
func (x *Certificate) GetLastReload() int64 {
if x != nil {
return x.LastReload
}
return 0
}
type Config struct {
state protoimpl.MessageState `protogen:"open.v1"`
// List of certificates to be served on server.
@@ -402,7 +386,7 @@ var File_transport_internet_tls_config_proto protoreflect.FileDescriptor
const file_transport_internet_tls_config_proto_rawDesc = "" +
"\n" +
"#transport/internet/tls/config.proto\x12\x1bxray.transport.internet.tls\x1a\x1ftransport/internet/config.proto\"\x8e\x04\n" +
"#transport/internet/tls/config.proto\x12\x1bxray.transport.internet.tls\x1a\x1ftransport/internet/config.proto\"\xcc\x03\n" +
"\vCertificate\x12 \n" +
"\vcertificate\x18\x01 \x01(\fR\vcertificate\x12\x10\n" +
"\x03key\x18\x02 \x01(\fR\x03key\x12D\n" +
@@ -412,12 +396,9 @@ const file_transport_internet_tls_config_proto_rawDesc = "" +
"\bkey_path\x18\x06 \x01(\tR\akeyPath\x12(\n" +
"\x10One_time_loading\x18\a \x01(\bR\x0eOneTimeLoading\x12\x1f\n" +
"\vbuild_chain\x18\b \x01(\bR\n" +
"buildChain\x12!\n" +
"\fparsed_cache\x18\t \x01(\fR\vparsedCache\x12\x1b\n" +
"\tocsp_data\x18\n" +
" \x01(\fR\bocspData\x12\x1f\n" +
"\vlast_reload\x18\v \x01(\x03R\n" +
"lastReload\"n\n" +
"buildChain\x12\x1d\n" +
"\n" +
"extra_data\x18\t \x01(\fR\textraData\"n\n" +
"\x05Usage\x12\x10\n" +
"\fENCIPHERMENT\x10\x00\x12\x14\n" +
"\x10AUTHORITY_VERIFY\x10\x01\x12\x13\n" +
+2 -4
View File
@@ -38,10 +38,8 @@ message Certificate {
bool build_chain = 8;
// Abuse proto data for storage hotreload data
bytes parsed_cache = 9;
bytes ocsp_data = 10;
int64 last_reload = 11;
// Abused proto data to storage some runtime data
bytes extra_data = 9;
}
message Config {
+5 -4
View File
@@ -42,14 +42,15 @@ func handleHotReload() {
}
func updateCert(_ uintptr, entry *Certificate) bool {
extraData := entry.extraData()
reloadInterval := int64(entry.OcspStapling)
if reloadInterval <= 0 {
reloadInterval = 3600
}
if entry.LastReload+reloadInterval >= time.Now().Unix() {
if extraData.lastReload+reloadInterval >= time.Now().Unix() {
return true
} else {
entry.LastReload = time.Now().Unix()
extraData.lastReload = time.Now().Unix()
}
if entry.CertificatePath != "" && entry.KeyPath != "" {
newCert, err := filesystem.ReadCert(entry.CertificatePath)
@@ -75,8 +76,8 @@ func updateCert(_ uintptr, entry *Certificate) bool {
}
if newOCSPData, err := ocsp.GetOCSPForCert(keyPair.Certificate); err != nil {
errors.LogWarningInner(context.Background(), err, "ignoring invalid OCSP")
} else if !slices.Equal(newOCSPData, entry.OcspData) {
entry.OcspData = newOCSPData
} else if OCSPData := extraData.ocspData.Load(); OCSPData == nil || !slices.Equal(newOCSPData, *OCSPData) {
extraData.ocspData.Store(&newOCSPData)
}
entry.parseX509KeyPair()
}