bytecategory
2026-05-23 17:24:37 +08:00
committed by GitHub
parent da9ba693cb
commit 359a28f876
6 changed files with 493 additions and 137 deletions
+24 -31
View File
@@ -10,7 +10,6 @@ import (
"io"
"net"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
@@ -40,6 +39,7 @@ type xdnsConnClient struct {
net.PacketConn
resolverAddrs []*net.UDPAddr
resolverTypes []uint16
resolverIdx uint32
resolverSend map[string]*atomic.Uint32
@@ -61,17 +61,15 @@ func NewConnClient(c *Config, raw net.PacketConn) (net.PacketConn, error) {
var domains []Name
var servers []string
var resolverTypes []uint16
for _, rs := range c.Resolvers {
parts := strings.Split(rs, "+udp://")
if len(parts) != 2 {
return nil, errors.New("invalid resolvers")
}
domain, err := ParseName(parts[0])
domain, server, resolverType, err := parseResolver(rs)
if err != nil {
return nil, err
return nil, errors.New("invalid resolvers").Base(err)
}
domains = append(domains, domain)
servers = append(servers, parts[1])
servers = append(servers, server)
resolverTypes = append(resolverTypes, resolverType)
}
var resolverAddrs []*net.UDPAddr
@@ -98,6 +96,7 @@ func NewConnClient(c *Config, raw net.PacketConn) (net.PacketConn, error) {
PacketConn: raw,
resolverAddrs: resolverAddrs,
resolverTypes: resolverTypes,
resolverIdx: 0,
resolverSend: resolverSend,
@@ -216,7 +215,7 @@ func (c *xdnsConnClient) sendLoop() {
default:
}
} else {
encoded, _ := encode(nil, c.clientID, c.domains[c.resolverIdx])
encoded, _ := encode(nil, c.clientID, c.domains[c.resolverIdx], c.resolverTypes[c.resolverIdx])
p = &packet{
p: encoded,
}
@@ -276,7 +275,8 @@ func (c *xdnsConnClient) WriteTo(p []byte, addr net.Addr) (n int, err error) {
return 0, io.ErrClosedPipe
}
encoded, err := encode(p, c.clientID, c.domains[c.resolverIdx%uint32(len(c.resolverAddrs))])
idx := c.resolverIdx % uint32(len(c.resolverAddrs))
encoded, err := encode(p, c.clientID, c.domains[idx], c.resolverTypes[idx])
if err != nil {
errors.LogDebug(context.Background(), addr, " xdns wireformat err ", err, " ", len(p))
return 0, nil
@@ -299,7 +299,7 @@ func (c *xdnsConnClient) Close() error {
return c.PacketConn.Close()
}
func encode(p []byte, clientID []byte, domain Name) ([]byte, error) {
func encode(p []byte, clientID []byte, domain Name, qtype uint16) ([]byte, error) {
var decoded []byte
{
if len(p) >= 224 {
@@ -338,7 +338,7 @@ func encode(p []byte, clientID []byte, domain Name) ([]byte, error) {
Question: []Question{
{
Name: name,
Type: RRTypeTXT,
Type: qtype,
Class: ClassIN,
},
},
@@ -396,29 +396,22 @@ func dnsResponsePayload(resp *Message, domains []Name) []byte {
return nil
}
if len(resp.Answer) != 1 {
if len(resp.Answer) == 0 {
return nil
}
answer := resp.Answer[0]
var ok bool
for _, domain := range domains {
_, ok = answer.Name.TrimSuffix(domain)
if ok {
break
for _, answer := range resp.Answer {
var ok bool
for _, domain := range domains {
_, ok = answer.Name.TrimSuffix(domain)
if ok {
break
}
}
if !ok {
return nil
}
}
if !ok {
return nil
}
if answer.Type != RRTypeTXT {
return nil
}
payload, err := DecodeRDataTXT(answer.Data)
if err != nil {
return nil
}
return payload
return decodeResponsePayload(resp.Answer)
}
+6
View File
@@ -45,8 +45,14 @@ var (
)
const (
// https://tools.ietf.org/html/rfc1035#section-3.2.2
RRTypeA = 1
// https://tools.ietf.org/html/rfc1035#section-3.2.2
RRTypeCNAME = 5
// https://tools.ietf.org/html/rfc1035#section-3.2.2
RRTypeTXT = 16
// https://tools.ietf.org/html/rfc3596#section-2.1
RRTypeAAAA = 28
// https://tools.ietf.org/html/rfc6891#section-6.1.1
RRTypeOPT = 41
@@ -593,3 +593,113 @@ func TestRDataTXTRoundTrip(t *testing.T) {
}
}
}
func TestIPAnswerPayloadRoundTrip(t *testing.T) {
for _, rrType := range []uint16{RRTypeA, RRTypeAAAA} {
for _, payload := range [][]byte{
{},
{0x01},
[]byte("hello world"),
bytes.Repeat([]byte{0xab}, payloadChunkSizeForType(rrType)*3+1),
} {
question := Question{
Name: mustParseName("example.com"),
Type: rrType,
Class: ClassIN,
}
answers, err := answersForPayload(question, responseTTL, payload)
if err != nil {
t.Fatalf("answersForPayload(%d) err = %v", rrType, err)
}
if len(answers) > 1 {
answers[0], answers[len(answers)-1] = answers[len(answers)-1], answers[0]
}
decoded := decodeResponsePayload(answers)
if !bytes.Equal(decoded, payload) {
t.Fatalf("rrType=%d decoded %x want %x", rrType, decoded, payload)
}
}
}
}
func TestParseResolver(t *testing.T) {
tests := []struct {
resolver string
rrType uint16
}{
{"example.com+udp://1.1.1.1:53", RRTypeTXT},
{"example.com:txt+udp://1.1.1.1:53", RRTypeTXT},
{"example.com:a+udp://1.1.1.1:53", RRTypeA},
{"example.com:aaaa+udp://1.1.1.1:53", RRTypeAAAA},
}
for _, test := range tests {
domain, server, rrType, err := parseResolver(test.resolver)
if err != nil {
t.Fatalf("parseResolver(%q) err = %v", test.resolver, err)
}
if domain.String() != "example.com" || server != "1.1.1.1:53" || rrType != test.rrType {
t.Fatalf("parseResolver(%q) = (%q, %q, %d)", test.resolver, domain.String(), server, rrType)
}
}
}
func TestParseDomainSpec(t *testing.T) {
tests := []struct {
spec string
def string
rrType uint16
wantErr bool
}{
{"example.com", "", 0, false},
{"example.com", "txt", RRTypeTXT, false},
{"example.com:a", "", RRTypeA, false},
{"example.com:aaaa", "", RRTypeAAAA, false},
{"example.com:doh", "", 0, true},
}
for _, test := range tests {
got, err := parseDomainSpec(test.spec, test.def)
if test.wantErr {
if err == nil {
t.Fatalf("parseDomainSpec(%q, %q) err = nil", test.spec, test.def)
}
continue
}
if err != nil {
t.Fatalf("parseDomainSpec(%q, %q) err = %v", test.spec, test.def, err)
}
if got.name.String() != "example.com" || got.rrType != test.rrType {
t.Fatalf("parseDomainSpec(%q, %q) = (%q, %d)", test.spec, test.def, got.name.String(), got.rrType)
}
}
}
func TestResponseForMethodRestriction(t *testing.T) {
query := &Message{
ID: 1,
Flags: 0x0100,
Question: []Question{{
Name: mustParseName("abc.example.com"),
Type: RRTypeTXT,
Class: ClassIN,
}},
Additional: []RR{{
Name: Name{},
Type: RRTypeOPT,
Class: 4096,
}},
}
resp, _ := responseFor(query, []domainSpec{{name: mustParseName("example.com"), rrType: RRTypeA}})
if resp == nil || resp.Rcode() != RcodeNameError {
t.Fatalf("responseFor method restriction rcode = %v", resp)
}
resp, _ = responseFor(query, []domainSpec{{name: mustParseName("example.com")}})
if resp == nil || resp.Rcode() != RcodeNoError {
t.Fatalf("responseFor unrestricted rcode = %v", resp)
}
}
@@ -0,0 +1,226 @@
package xdns
import "bytes"
const ipRecordHeaderSize = 2
func maxEncodedPayloadForType(rrType uint16) int {
switch rrType {
case RRTypeA:
return maxEncodedPayloadA
case RRTypeAAAA:
return maxEncodedPayloadAAAA
default:
return maxEncodedPayloadTXT
}
}
func rrDataSizeForType(rrType uint16) int {
switch rrType {
case RRTypeA:
return 4
case RRTypeAAAA:
return 16
default:
return 0
}
}
func payloadChunkSizeForType(rrType uint16) int {
size := rrDataSizeForType(rrType)
if size <= ipRecordHeaderSize {
return 0
}
return size - ipRecordHeaderSize
}
func answersForPayload(question Question, ttl uint32, payload []byte) ([]RR, error) {
switch question.Type {
case RRTypeTXT:
return []RR{
{
Name: question.Name,
Type: question.Type,
Class: question.Class,
TTL: ttl,
Data: EncodeRDataTXT(payload),
},
}, nil
case RRTypeA, RRTypeAAAA:
return ipAnswersForPayload(question, ttl, payload)
default:
return nil, ErrIntegerOverflow
}
}
func ipAnswersForPayload(question Question, ttl uint32, payload []byte) ([]RR, error) {
chunkSize := payloadChunkSizeForType(question.Type)
rrDataSize := rrDataSizeForType(question.Type)
if chunkSize == 0 || rrDataSize == 0 {
return nil, ErrIntegerOverflow
}
numRecords := 1
if len(payload) > 0 {
numRecords = (len(payload) + chunkSize - 1) / chunkSize
}
if numRecords > 256 {
return nil, ErrIntegerOverflow
}
answers := make([]RR, 0, numRecords)
for i := 0; i < numRecords; i++ {
offset := i * chunkSize
n := len(payload) - offset
if n < 0 {
n = 0
}
if n > chunkSize {
n = chunkSize
}
data := make([]byte, rrDataSize)
data[0] = byte(i)
data[1] = byte(n)
copy(data[ipRecordHeaderSize:], payload[offset:offset+n])
answers = append(answers, RR{
Name: question.Name,
Type: question.Type,
Class: question.Class,
TTL: ttl,
Data: data,
})
}
return answers, nil
}
func decodeResponsePayload(answers []RR) []byte {
if len(answers) == 0 {
return nil
}
switch answers[0].Type {
case RRTypeTXT:
if len(answers) != 1 {
return nil
}
payload, err := DecodeRDataTXT(answers[0].Data)
if err != nil {
return nil
}
return payload
case RRTypeA, RRTypeAAAA:
return decodeIPAnswerPayload(answers, answers[0].Type)
default:
return nil
}
}
func decodeIPAnswerPayload(answers []RR, rrType uint16) []byte {
chunkSize := payloadChunkSizeForType(rrType)
rrDataSize := rrDataSizeForType(rrType)
if chunkSize == 0 || rrDataSize == 0 || len(answers) > 256 {
return nil
}
parts := make([][]byte, len(answers))
for _, answer := range answers {
if answer.Type != rrType || len(answer.Data) != rrDataSize {
return nil
}
idx := int(answer.Data[0])
n := int(answer.Data[1])
if idx >= len(answers) || n > chunkSize || parts[idx] != nil {
return nil
}
part := make([]byte, n)
copy(part, answer.Data[ipRecordHeaderSize:ipRecordHeaderSize+n])
parts[idx] = part
}
var payload bytes.Buffer
for _, part := range parts {
if part == nil {
return nil
}
payload.Write(part)
}
return payload.Bytes()
}
func computeMaxEncodedPayload(limit int) int {
return computeMaxEncodedPayloadForType(limit, RRTypeTXT)
}
func computeMaxEncodedPayloadForType(limit int, rrType uint16) int {
maxLengthName, err := NewName([][]byte{
[]byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
[]byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
[]byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
[]byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
})
if err != nil {
panic(err)
}
{
n := 0
for _, label := range maxLengthName {
n += len(label) + 1
}
n += 1
if n != 255 {
panic("computeMaxEncodedPayload n != 255")
}
}
queryLimit := uint16(limit)
if int(queryLimit) != limit {
queryLimit = 0xffff
}
query := &Message{
Question: []Question{
{
Name: maxLengthName,
Type: rrType,
Class: ClassIN,
},
},
Additional: []RR{
{
Name: Name{},
Type: RRTypeOPT,
Class: queryLimit,
TTL: 0,
Data: []byte{},
},
},
}
resp, _ := responseFor(query, []domainSpec{{name: Name{[]byte{}}}})
low := 0
high := 32768
if chunkSize := payloadChunkSizeForType(rrType); chunkSize > 0 {
high = 256*chunkSize + 1
}
for low+1 < high {
mid := (low + high) / 2
resp.Answer, err = answersForPayload(query.Question[0], responseTTL, make([]byte, mid))
if err != nil {
panic(err)
}
buf, err := resp.WireFormat()
if err != nil {
panic(err)
}
if len(buf) <= limit {
low = mid
} else {
high = mid
}
}
return low
}
+47 -106
View File
@@ -21,8 +21,10 @@ const (
)
var (
maxUDPPayload = 1280 - 40 - 8
maxEncodedPayload = computeMaxEncodedPayload(maxUDPPayload)
maxUDPPayload = 1280 - 40 - 8
maxEncodedPayloadTXT = computeMaxEncodedPayloadForType(maxUDPPayload, RRTypeTXT)
maxEncodedPayloadA = computeMaxEncodedPayloadForType(maxUDPPayload, RRTypeA)
maxEncodedPayloadAAAA = computeMaxEncodedPayloadForType(maxUDPPayload, RRTypeAAAA)
)
func clientIDToAddr(clientID [8]byte) *net.UDPAddr {
@@ -44,15 +46,16 @@ type record struct {
}
type queue struct {
last time.Time
queue chan []byte
stash chan []byte
last time.Time
rrType uint16
queue chan []byte
stash chan []byte
}
type xdnsConnServer struct {
net.PacketConn
domains []Name
domains []domainSpec
ch chan *record
readQueue chan *packet
@@ -66,9 +69,9 @@ func NewConnServer(c *Config, raw net.PacketConn) (net.PacketConn, error) {
if len(c.Domains) == 0 {
return nil, errors.New("empty domains")
}
domains := make([]Name, 0, len(c.Domains))
domains := make([]domainSpec, 0, len(c.Domains))
for _, domain := range c.Domains {
domain, err := ParseName(domain)
domain, err := parseDomainSpec(domain, "")
if err != nil {
return nil, err
}
@@ -234,6 +237,7 @@ func (c *xdnsConnServer) recvLoop() {
func (c *xdnsConnServer) sendLoop() {
var nextRec *record
for {
var err error
rec := nextRec
nextRec = nil
@@ -246,18 +250,8 @@ func (c *xdnsConnServer) sendLoop() {
}
if rec.Resp.Rcode() == RcodeNoError && len(rec.Resp.Question) == 1 {
rec.Resp.Answer = []RR{
{
Name: rec.Resp.Question[0].Name,
Type: rec.Resp.Question[0].Type,
Class: rec.Resp.Question[0].Class,
TTL: responseTTL,
Data: nil,
},
}
var payload bytes.Buffer
limit := maxEncodedPayload
limit := maxEncodedPayloadForType(rec.Resp.Question[0].Type)
timer := time.NewTimer(maxResponseDelay)
for {
@@ -267,6 +261,7 @@ func (c *xdnsConnServer) sendLoop() {
c.mutex.Unlock()
return
}
q.rrType = rec.Resp.Question[0].Type
c.mutex.Unlock()
var p []byte
@@ -294,7 +289,11 @@ func (c *xdnsConnServer) sendLoop() {
}
limit -= 2 + len(p)
if payload.Len() > 0 && limit < 0 {
if limit < 0 {
if payload.Len() == 0 {
errors.LogDebug(context.Background(), rec.Addr, " ", rec.ClientAddr, " xdns payload too large for rrtype ", rec.Resp.Question[0].Type, " ", len(p))
continue
}
c.stash(q, p)
break
}
@@ -308,7 +307,11 @@ func (c *xdnsConnServer) sendLoop() {
}
timer.Stop()
rec.Resp.Answer[0].Data = EncodeRDataTXT(payload.Bytes())
rec.Resp.Answer, err = answersForPayload(rec.Resp.Question[0], responseTTL, payload.Bytes())
if err != nil {
errors.LogDebug(context.Background(), rec.Addr, " ", rec.ClientAddr, " xdns encode err ", err)
continue
}
}
buf, err := rec.Resp.WireFormat()
@@ -349,11 +352,6 @@ func (c *xdnsConnServer) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
}
func (c *xdnsConnServer) WriteTo(p []byte, addr net.Addr) (n int, err error) {
if len(p)+2 > maxEncodedPayload {
errors.LogDebug(context.Background(), addr, " mask write err short write ", len(p), "+2 > ", maxEncodedPayload)
return 0, nil
}
c.mutex.Lock()
defer c.mutex.Unlock()
@@ -361,6 +359,14 @@ func (c *xdnsConnServer) WriteTo(p []byte, addr net.Addr) (n int, err error) {
if q == nil {
return 0, io.ErrClosedPipe
}
limit := maxEncodedPayloadForType(q.rrType)
if q.rrType == 0 {
limit = maxEncodedPayloadTXT
}
if len(p)+2 > limit {
errors.LogDebug(context.Background(), addr, " mask write err short write ", len(p), "+2 > ", limit)
return 0, nil
}
buf := make([]byte, len(p))
copy(buf, p)
@@ -406,7 +412,7 @@ func nextPacketServer(r *bytes.Reader) ([]byte, error) {
}
}
func responseFor(query *Message, domains []Name) (*Message, []byte) {
func responseFor(query *Message, domains []domainSpec) (*Message, []byte) {
resp := &Message{
ID: query.ID,
Flags: 0x8000,
@@ -454,11 +460,15 @@ func responseFor(query *Message, domains []Name) (*Message, []byte) {
}
question := query.Question[0]
var prefix Name
var ok bool
var (
prefix Name
ok bool
match domainSpec
)
for _, domain := range domains {
prefix, ok = question.Name.TrimSuffix(domain)
prefix, ok = question.Name.TrimSuffix(domain.name)
if ok {
match = domain
break
}
}
@@ -473,7 +483,13 @@ func responseFor(query *Message, domains []Name) (*Message, []byte) {
return resp, nil
}
if question.Type != RRTypeTXT {
switch question.Type {
case RRTypeTXT, RRTypeA, RRTypeAAAA:
default:
resp.Flags |= RcodeNameError
return resp, nil
}
if match.rrType != 0 && question.Type != match.rrType {
resp.Flags |= RcodeNameError
return resp, nil
}
@@ -494,78 +510,3 @@ func responseFor(query *Message, domains []Name) (*Message, []byte) {
return resp, payload
}
func computeMaxEncodedPayload(limit int) int {
maxLengthName, err := NewName([][]byte{
[]byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
[]byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
[]byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
[]byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
})
if err != nil {
panic(err)
}
{
n := 0
for _, label := range maxLengthName {
n += len(label) + 1
}
n += 1
if n != 255 {
panic("computeMaxEncodedPayload n != 255")
}
}
queryLimit := uint16(limit)
if int(queryLimit) != limit {
queryLimit = 0xffff
}
query := &Message{
Question: []Question{
{
Name: maxLengthName,
Type: RRTypeTXT,
Class: RRTypeTXT,
},
},
Additional: []RR{
{
Name: Name{},
Type: RRTypeOPT,
Class: queryLimit,
TTL: 0,
Data: []byte{},
},
},
}
resp, _ := responseFor(query, []Name{[][]byte{}})
resp.Answer = []RR{
{
Name: query.Question[0].Name,
Type: query.Question[0].Type,
Class: query.Question[0].Class,
TTL: responseTTL,
Data: nil,
},
}
low := 0
high := 32768
for low+1 < high {
mid := (low + high) / 2
resp.Answer[0].Data = EncodeRDataTXT(make([]byte, mid))
buf, err := resp.WireFormat()
if err != nil {
panic(err)
}
if len(buf) <= limit {
low = mid
} else {
high = mid
}
}
return low
}
+80
View File
@@ -0,0 +1,80 @@
package xdns
import (
"strings"
"github.com/xtls/xray-core/common/errors"
)
type domainSpec struct {
name Name
rrType uint16
}
func rrTypeFromMethod(method string) (uint16, error) {
switch strings.ToLower(method) {
case "", "txt":
return RRTypeTXT, nil
case "a":
return RRTypeA, nil
case "aaaa":
return RRTypeAAAA, nil
default:
return 0, errors.New("unsupported method")
}
}
func parseDomainSpec(s string, defaultMethod string) (domainSpec, error) {
domainPart := s
method := ""
hasMethod := false
if i := strings.LastIndex(s, ":"); i >= 0 {
domainPart = s[:i]
method = s[i+1:]
hasMethod = true
} else if defaultMethod != "" {
method = defaultMethod
hasMethod = true
}
if domainPart == "" {
return domainSpec{}, errors.New("empty domain")
}
name, err := ParseName(domainPart)
if err != nil {
return domainSpec{}, err
}
rrType := uint16(0)
if hasMethod {
var err error
rrType, err = rrTypeFromMethod(method)
if err != nil {
return domainSpec{}, err
}
}
return domainSpec{
name: name,
rrType: rrType,
}, nil
}
func parseResolver(s string) (Name, string, uint16, error) {
head, server, ok := strings.Cut(s, "+udp://")
if !ok {
return nil, "", 0, errors.New("invalid resolver scheme")
}
if server == "" {
return nil, "", 0, errors.New("empty resolver server")
}
spec, err := parseDomainSpec(head, "txt")
if err != nil {
return nil, "", 0, err
}
return spec.name, server, spec.rrType, nil
}