refactor(dns): enhance cache safety, optimize performance, and refactor query logic (#5248)

This commit is contained in:
Meow
2025-11-21 10:45:09 +08:00
committed by GitHub
parent 27ad487545
commit b40bf56e4e
7 changed files with 454 additions and 325 deletions
+16 -54
View File
@@ -4,7 +4,6 @@ import (
"bytes"
"context"
"crypto/tls"
go_errors "errors"
"fmt"
"io"
"net/http"
@@ -121,10 +120,16 @@ func (s *DoHNameServer) newReqID() uint16 {
return 0
}
func (s *DoHNameServer) sendQuery(ctx context.Context, noResponseErrCh chan<- error, domain string, option dns_feature.IPOption) {
errors.LogInfo(ctx, s.Name(), " querying: ", domain)
// getCacheController implements CachedNameserver.
func (s *DoHNameServer) getCacheController() *CacheController {
return s.cacheController
}
if s.Name()+"." == "DOH//"+domain {
// sendQuery implements CachedNameserver.
func (s *DoHNameServer) sendQuery(ctx context.Context, noResponseErrCh chan<- error, fqdn string, option dns_feature.IPOption) {
errors.LogInfo(ctx, s.Name(), " querying: ", fqdn)
if s.Name()+"." == "DOH//"+fqdn {
errors.LogError(ctx, s.Name(), " tries to resolve itself! Use IP or set \"hosts\" instead.")
noResponseErrCh <- errors.New("tries to resolve itself!", s.Name())
return
@@ -132,7 +137,7 @@ func (s *DoHNameServer) sendQuery(ctx context.Context, noResponseErrCh chan<- er
// As we don't want our traffic pattern looks like DoH, we use Random-Length Padding instead of Block-Length Padding recommended in RFC 8467
// Although DoH server like 1.1.1.1 will pad the response to Block-Length 468, at least it is better than no padding for response at all
reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(s.clientIP, int(crypto.RandBetween(100, 300))))
reqs := buildReqMsgs(fqdn, option, s.newReqID, genEDNS0Options(s.clientIP, int(crypto.RandBetween(100, 300))))
var deadline time.Time
if d, ok := ctx.Deadline(); ok {
@@ -166,23 +171,23 @@ func (s *DoHNameServer) sendQuery(ctx context.Context, noResponseErrCh chan<- er
b, err := dns.PackMessage(r.msg)
if err != nil {
errors.LogErrorInner(ctx, err, "failed to pack dns query for ", domain)
errors.LogErrorInner(ctx, err, "failed to pack dns query for ", fqdn)
noResponseErrCh <- err
return
}
resp, err := s.dohHTTPSContext(dnsCtx, b.Bytes())
if err != nil {
errors.LogErrorInner(ctx, err, "failed to retrieve response for ", domain)
errors.LogErrorInner(ctx, err, "failed to retrieve response for ", fqdn)
noResponseErrCh <- err
return
}
rec, err := parseResponse(resp)
if err != nil {
errors.LogErrorInner(ctx, err, "failed to handle DOH response for ", domain)
errors.LogErrorInner(ctx, err, "failed to handle DOH response for ", fqdn)
noResponseErrCh <- err
return
}
s.cacheController.updateIP(r, rec)
s.cacheController.updateRecord(r, rec)
}(req)
}
}
@@ -216,49 +221,6 @@ func (s *DoHNameServer) dohHTTPSContext(ctx context.Context, b []byte) ([]byte,
}
// QueryIP implements Server.
func (s *DoHNameServer) QueryIP(ctx context.Context, domain string, option dns_feature.IPOption) ([]net.IP, uint32, error) { // nolint: dupl
fqdn := Fqdn(domain)
sub4, sub6 := s.cacheController.registerSubscribers(fqdn, option)
defer closeSubscribers(sub4, sub6)
if s.cacheController.disableCache {
errors.LogDebug(ctx, "DNS cache is disabled. Querying IP for ", domain, " at ", s.Name())
} else {
ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option)
if !go_errors.Is(err, errRecordNotFound) {
errors.LogDebugInner(ctx, err, s.Name(), " cache HIT ", domain, " -> ", ips)
log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSCacheHit, Elapsed: 0, Error: err})
return ips, ttl, err
}
}
noResponseErrCh := make(chan error, 2)
s.sendQuery(ctx, noResponseErrCh, fqdn, option)
start := time.Now()
if sub4 != nil {
select {
case <-ctx.Done():
return nil, 0, ctx.Err()
case err := <-noResponseErrCh:
return nil, 0, err
case <-sub4.Wait():
sub4.Close()
}
}
if sub6 != nil {
select {
case <-ctx.Done():
return nil, 0, ctx.Err()
case err := <-noResponseErrCh:
return nil, 0, err
case <-sub6.Wait():
sub6.Close()
}
}
ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option)
log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err})
return ips, ttl, err
func (s *DoHNameServer) QueryIP(ctx context.Context, domain string, option dns_feature.IPOption) ([]net.IP, uint32, error) {
return queryIP(ctx, s, domain, option)
}