mirror of
https://github.com/XTLS/Xray-core.git
synced 2026-06-08 14:11:54 +00:00
XDNS finalmask: Support AAAA & A (#6123)
https://github.com/XTLS/Xray-core/pull/6123#issuecomment-4439994266 https://github.com/XTLS/Xray-core/pull/6123#issuecomment-4441819696 https://github.com/XTLS/Xray-core/pull/6123#issuecomment-4522352460 Example: https://github.com/XTLS/Xray-core/pull/6123#issue-4436283689 Document: https://xtls.github.io/config/transports/finalmask.html#xdns
This commit is contained in:
@@ -10,7 +10,6 @@ import (
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -40,6 +39,7 @@ type xdnsConnClient struct {
|
||||
net.PacketConn
|
||||
|
||||
resolverAddrs []*net.UDPAddr
|
||||
resolverTypes []uint16
|
||||
resolverIdx uint32
|
||||
resolverSend map[string]*atomic.Uint32
|
||||
|
||||
@@ -61,17 +61,15 @@ func NewConnClient(c *Config, raw net.PacketConn) (net.PacketConn, error) {
|
||||
|
||||
var domains []Name
|
||||
var servers []string
|
||||
var resolverTypes []uint16
|
||||
for _, rs := range c.Resolvers {
|
||||
parts := strings.Split(rs, "+udp://")
|
||||
if len(parts) != 2 {
|
||||
return nil, errors.New("invalid resolvers")
|
||||
}
|
||||
domain, err := ParseName(parts[0])
|
||||
domain, server, resolverType, err := parseResolver(rs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.New("invalid resolvers").Base(err)
|
||||
}
|
||||
domains = append(domains, domain)
|
||||
servers = append(servers, parts[1])
|
||||
servers = append(servers, server)
|
||||
resolverTypes = append(resolverTypes, resolverType)
|
||||
}
|
||||
|
||||
var resolverAddrs []*net.UDPAddr
|
||||
@@ -98,6 +96,7 @@ func NewConnClient(c *Config, raw net.PacketConn) (net.PacketConn, error) {
|
||||
PacketConn: raw,
|
||||
|
||||
resolverAddrs: resolverAddrs,
|
||||
resolverTypes: resolverTypes,
|
||||
resolverIdx: 0,
|
||||
resolverSend: resolverSend,
|
||||
|
||||
@@ -216,7 +215,7 @@ func (c *xdnsConnClient) sendLoop() {
|
||||
default:
|
||||
}
|
||||
} else {
|
||||
encoded, _ := encode(nil, c.clientID, c.domains[c.resolverIdx])
|
||||
encoded, _ := encode(nil, c.clientID, c.domains[c.resolverIdx], c.resolverTypes[c.resolverIdx])
|
||||
p = &packet{
|
||||
p: encoded,
|
||||
}
|
||||
@@ -276,7 +275,8 @@ func (c *xdnsConnClient) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||
return 0, io.ErrClosedPipe
|
||||
}
|
||||
|
||||
encoded, err := encode(p, c.clientID, c.domains[c.resolverIdx%uint32(len(c.resolverAddrs))])
|
||||
idx := c.resolverIdx % uint32(len(c.resolverAddrs))
|
||||
encoded, err := encode(p, c.clientID, c.domains[idx], c.resolverTypes[idx])
|
||||
if err != nil {
|
||||
errors.LogDebug(context.Background(), addr, " xdns wireformat err ", err, " ", len(p))
|
||||
return 0, nil
|
||||
@@ -299,7 +299,7 @@ func (c *xdnsConnClient) Close() error {
|
||||
return c.PacketConn.Close()
|
||||
}
|
||||
|
||||
func encode(p []byte, clientID []byte, domain Name) ([]byte, error) {
|
||||
func encode(p []byte, clientID []byte, domain Name, qtype uint16) ([]byte, error) {
|
||||
var decoded []byte
|
||||
{
|
||||
if len(p) >= 224 {
|
||||
@@ -338,7 +338,7 @@ func encode(p []byte, clientID []byte, domain Name) ([]byte, error) {
|
||||
Question: []Question{
|
||||
{
|
||||
Name: name,
|
||||
Type: RRTypeTXT,
|
||||
Type: qtype,
|
||||
Class: ClassIN,
|
||||
},
|
||||
},
|
||||
@@ -396,29 +396,22 @@ func dnsResponsePayload(resp *Message, domains []Name) []byte {
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(resp.Answer) != 1 {
|
||||
if len(resp.Answer) == 0 {
|
||||
return nil
|
||||
}
|
||||
answer := resp.Answer[0]
|
||||
|
||||
var ok bool
|
||||
for _, domain := range domains {
|
||||
_, ok = answer.Name.TrimSuffix(domain)
|
||||
if ok {
|
||||
break
|
||||
for _, answer := range resp.Answer {
|
||||
var ok bool
|
||||
for _, domain := range domains {
|
||||
_, ok = answer.Name.TrimSuffix(domain)
|
||||
if ok {
|
||||
break
|
||||
}
|
||||
}
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
if answer.Type != RRTypeTXT {
|
||||
return nil
|
||||
}
|
||||
payload, err := DecodeRDataTXT(answer.Data)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return payload
|
||||
return decodeResponsePayload(resp.Answer)
|
||||
}
|
||||
|
||||
@@ -45,8 +45,14 @@ var (
|
||||
)
|
||||
|
||||
const (
|
||||
// https://tools.ietf.org/html/rfc1035#section-3.2.2
|
||||
RRTypeA = 1
|
||||
// https://tools.ietf.org/html/rfc1035#section-3.2.2
|
||||
RRTypeCNAME = 5
|
||||
// https://tools.ietf.org/html/rfc1035#section-3.2.2
|
||||
RRTypeTXT = 16
|
||||
// https://tools.ietf.org/html/rfc3596#section-2.1
|
||||
RRTypeAAAA = 28
|
||||
// https://tools.ietf.org/html/rfc6891#section-6.1.1
|
||||
RRTypeOPT = 41
|
||||
|
||||
|
||||
@@ -593,3 +593,113 @@ func TestRDataTXTRoundTrip(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPAnswerPayloadRoundTrip(t *testing.T) {
|
||||
for _, rrType := range []uint16{RRTypeA, RRTypeAAAA} {
|
||||
for _, payload := range [][]byte{
|
||||
{},
|
||||
{0x01},
|
||||
[]byte("hello world"),
|
||||
bytes.Repeat([]byte{0xab}, payloadChunkSizeForType(rrType)*3+1),
|
||||
} {
|
||||
question := Question{
|
||||
Name: mustParseName("example.com"),
|
||||
Type: rrType,
|
||||
Class: ClassIN,
|
||||
}
|
||||
answers, err := answersForPayload(question, responseTTL, payload)
|
||||
if err != nil {
|
||||
t.Fatalf("answersForPayload(%d) err = %v", rrType, err)
|
||||
}
|
||||
|
||||
if len(answers) > 1 {
|
||||
answers[0], answers[len(answers)-1] = answers[len(answers)-1], answers[0]
|
||||
}
|
||||
|
||||
decoded := decodeResponsePayload(answers)
|
||||
if !bytes.Equal(decoded, payload) {
|
||||
t.Fatalf("rrType=%d decoded %x want %x", rrType, decoded, payload)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseResolver(t *testing.T) {
|
||||
tests := []struct {
|
||||
resolver string
|
||||
rrType uint16
|
||||
}{
|
||||
{"example.com+udp://1.1.1.1:53", RRTypeTXT},
|
||||
{"example.com:txt+udp://1.1.1.1:53", RRTypeTXT},
|
||||
{"example.com:a+udp://1.1.1.1:53", RRTypeA},
|
||||
{"example.com:aaaa+udp://1.1.1.1:53", RRTypeAAAA},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
domain, server, rrType, err := parseResolver(test.resolver)
|
||||
if err != nil {
|
||||
t.Fatalf("parseResolver(%q) err = %v", test.resolver, err)
|
||||
}
|
||||
if domain.String() != "example.com" || server != "1.1.1.1:53" || rrType != test.rrType {
|
||||
t.Fatalf("parseResolver(%q) = (%q, %q, %d)", test.resolver, domain.String(), server, rrType)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDomainSpec(t *testing.T) {
|
||||
tests := []struct {
|
||||
spec string
|
||||
def string
|
||||
rrType uint16
|
||||
wantErr bool
|
||||
}{
|
||||
{"example.com", "", 0, false},
|
||||
{"example.com", "txt", RRTypeTXT, false},
|
||||
{"example.com:a", "", RRTypeA, false},
|
||||
{"example.com:aaaa", "", RRTypeAAAA, false},
|
||||
{"example.com:doh", "", 0, true},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
got, err := parseDomainSpec(test.spec, test.def)
|
||||
if test.wantErr {
|
||||
if err == nil {
|
||||
t.Fatalf("parseDomainSpec(%q, %q) err = nil", test.spec, test.def)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("parseDomainSpec(%q, %q) err = %v", test.spec, test.def, err)
|
||||
}
|
||||
if got.name.String() != "example.com" || got.rrType != test.rrType {
|
||||
t.Fatalf("parseDomainSpec(%q, %q) = (%q, %d)", test.spec, test.def, got.name.String(), got.rrType)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponseForMethodRestriction(t *testing.T) {
|
||||
query := &Message{
|
||||
ID: 1,
|
||||
Flags: 0x0100,
|
||||
Question: []Question{{
|
||||
Name: mustParseName("abc.example.com"),
|
||||
Type: RRTypeTXT,
|
||||
Class: ClassIN,
|
||||
}},
|
||||
Additional: []RR{{
|
||||
Name: Name{},
|
||||
Type: RRTypeOPT,
|
||||
Class: 4096,
|
||||
}},
|
||||
}
|
||||
|
||||
resp, _ := responseFor(query, []domainSpec{{name: mustParseName("example.com"), rrType: RRTypeA}})
|
||||
if resp == nil || resp.Rcode() != RcodeNameError {
|
||||
t.Fatalf("responseFor method restriction rcode = %v", resp)
|
||||
}
|
||||
|
||||
resp, _ = responseFor(query, []domainSpec{{name: mustParseName("example.com")}})
|
||||
if resp == nil || resp.Rcode() != RcodeNoError {
|
||||
t.Fatalf("responseFor unrestricted rcode = %v", resp)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,226 @@
|
||||
package xdns
|
||||
|
||||
import "bytes"
|
||||
|
||||
const ipRecordHeaderSize = 2
|
||||
|
||||
func maxEncodedPayloadForType(rrType uint16) int {
|
||||
switch rrType {
|
||||
case RRTypeA:
|
||||
return maxEncodedPayloadA
|
||||
case RRTypeAAAA:
|
||||
return maxEncodedPayloadAAAA
|
||||
default:
|
||||
return maxEncodedPayloadTXT
|
||||
}
|
||||
}
|
||||
|
||||
func rrDataSizeForType(rrType uint16) int {
|
||||
switch rrType {
|
||||
case RRTypeA:
|
||||
return 4
|
||||
case RRTypeAAAA:
|
||||
return 16
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
func payloadChunkSizeForType(rrType uint16) int {
|
||||
size := rrDataSizeForType(rrType)
|
||||
if size <= ipRecordHeaderSize {
|
||||
return 0
|
||||
}
|
||||
return size - ipRecordHeaderSize
|
||||
}
|
||||
|
||||
func answersForPayload(question Question, ttl uint32, payload []byte) ([]RR, error) {
|
||||
switch question.Type {
|
||||
case RRTypeTXT:
|
||||
return []RR{
|
||||
{
|
||||
Name: question.Name,
|
||||
Type: question.Type,
|
||||
Class: question.Class,
|
||||
TTL: ttl,
|
||||
Data: EncodeRDataTXT(payload),
|
||||
},
|
||||
}, nil
|
||||
case RRTypeA, RRTypeAAAA:
|
||||
return ipAnswersForPayload(question, ttl, payload)
|
||||
default:
|
||||
return nil, ErrIntegerOverflow
|
||||
}
|
||||
}
|
||||
|
||||
func ipAnswersForPayload(question Question, ttl uint32, payload []byte) ([]RR, error) {
|
||||
chunkSize := payloadChunkSizeForType(question.Type)
|
||||
rrDataSize := rrDataSizeForType(question.Type)
|
||||
if chunkSize == 0 || rrDataSize == 0 {
|
||||
return nil, ErrIntegerOverflow
|
||||
}
|
||||
|
||||
numRecords := 1
|
||||
if len(payload) > 0 {
|
||||
numRecords = (len(payload) + chunkSize - 1) / chunkSize
|
||||
}
|
||||
if numRecords > 256 {
|
||||
return nil, ErrIntegerOverflow
|
||||
}
|
||||
|
||||
answers := make([]RR, 0, numRecords)
|
||||
for i := 0; i < numRecords; i++ {
|
||||
offset := i * chunkSize
|
||||
n := len(payload) - offset
|
||||
if n < 0 {
|
||||
n = 0
|
||||
}
|
||||
if n > chunkSize {
|
||||
n = chunkSize
|
||||
}
|
||||
|
||||
data := make([]byte, rrDataSize)
|
||||
data[0] = byte(i)
|
||||
data[1] = byte(n)
|
||||
copy(data[ipRecordHeaderSize:], payload[offset:offset+n])
|
||||
|
||||
answers = append(answers, RR{
|
||||
Name: question.Name,
|
||||
Type: question.Type,
|
||||
Class: question.Class,
|
||||
TTL: ttl,
|
||||
Data: data,
|
||||
})
|
||||
}
|
||||
|
||||
return answers, nil
|
||||
}
|
||||
|
||||
func decodeResponsePayload(answers []RR) []byte {
|
||||
if len(answers) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch answers[0].Type {
|
||||
case RRTypeTXT:
|
||||
if len(answers) != 1 {
|
||||
return nil
|
||||
}
|
||||
payload, err := DecodeRDataTXT(answers[0].Data)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return payload
|
||||
case RRTypeA, RRTypeAAAA:
|
||||
return decodeIPAnswerPayload(answers, answers[0].Type)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func decodeIPAnswerPayload(answers []RR, rrType uint16) []byte {
|
||||
chunkSize := payloadChunkSizeForType(rrType)
|
||||
rrDataSize := rrDataSizeForType(rrType)
|
||||
if chunkSize == 0 || rrDataSize == 0 || len(answers) > 256 {
|
||||
return nil
|
||||
}
|
||||
|
||||
parts := make([][]byte, len(answers))
|
||||
for _, answer := range answers {
|
||||
if answer.Type != rrType || len(answer.Data) != rrDataSize {
|
||||
return nil
|
||||
}
|
||||
idx := int(answer.Data[0])
|
||||
n := int(answer.Data[1])
|
||||
if idx >= len(answers) || n > chunkSize || parts[idx] != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
part := make([]byte, n)
|
||||
copy(part, answer.Data[ipRecordHeaderSize:ipRecordHeaderSize+n])
|
||||
parts[idx] = part
|
||||
}
|
||||
|
||||
var payload bytes.Buffer
|
||||
for _, part := range parts {
|
||||
if part == nil {
|
||||
return nil
|
||||
}
|
||||
payload.Write(part)
|
||||
}
|
||||
return payload.Bytes()
|
||||
}
|
||||
|
||||
func computeMaxEncodedPayload(limit int) int {
|
||||
return computeMaxEncodedPayloadForType(limit, RRTypeTXT)
|
||||
}
|
||||
|
||||
func computeMaxEncodedPayloadForType(limit int, rrType uint16) int {
|
||||
maxLengthName, err := NewName([][]byte{
|
||||
[]byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
|
||||
[]byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
|
||||
[]byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
|
||||
[]byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
{
|
||||
n := 0
|
||||
for _, label := range maxLengthName {
|
||||
n += len(label) + 1
|
||||
}
|
||||
n += 1
|
||||
if n != 255 {
|
||||
panic("computeMaxEncodedPayload n != 255")
|
||||
}
|
||||
}
|
||||
|
||||
queryLimit := uint16(limit)
|
||||
if int(queryLimit) != limit {
|
||||
queryLimit = 0xffff
|
||||
}
|
||||
query := &Message{
|
||||
Question: []Question{
|
||||
{
|
||||
Name: maxLengthName,
|
||||
Type: rrType,
|
||||
Class: ClassIN,
|
||||
},
|
||||
},
|
||||
Additional: []RR{
|
||||
{
|
||||
Name: Name{},
|
||||
Type: RRTypeOPT,
|
||||
Class: queryLimit,
|
||||
TTL: 0,
|
||||
Data: []byte{},
|
||||
},
|
||||
},
|
||||
}
|
||||
resp, _ := responseFor(query, []domainSpec{{name: Name{[]byte{}}}})
|
||||
|
||||
low := 0
|
||||
high := 32768
|
||||
if chunkSize := payloadChunkSizeForType(rrType); chunkSize > 0 {
|
||||
high = 256*chunkSize + 1
|
||||
}
|
||||
for low+1 < high {
|
||||
mid := (low + high) / 2
|
||||
resp.Answer, err = answersForPayload(query.Question[0], responseTTL, make([]byte, mid))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
buf, err := resp.WireFormat()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if len(buf) <= limit {
|
||||
low = mid
|
||||
} else {
|
||||
high = mid
|
||||
}
|
||||
}
|
||||
|
||||
return low
|
||||
}
|
||||
@@ -21,8 +21,10 @@ const (
|
||||
)
|
||||
|
||||
var (
|
||||
maxUDPPayload = 1280 - 40 - 8
|
||||
maxEncodedPayload = computeMaxEncodedPayload(maxUDPPayload)
|
||||
maxUDPPayload = 1280 - 40 - 8
|
||||
maxEncodedPayloadTXT = computeMaxEncodedPayloadForType(maxUDPPayload, RRTypeTXT)
|
||||
maxEncodedPayloadA = computeMaxEncodedPayloadForType(maxUDPPayload, RRTypeA)
|
||||
maxEncodedPayloadAAAA = computeMaxEncodedPayloadForType(maxUDPPayload, RRTypeAAAA)
|
||||
)
|
||||
|
||||
func clientIDToAddr(clientID [8]byte) *net.UDPAddr {
|
||||
@@ -44,15 +46,16 @@ type record struct {
|
||||
}
|
||||
|
||||
type queue struct {
|
||||
last time.Time
|
||||
queue chan []byte
|
||||
stash chan []byte
|
||||
last time.Time
|
||||
rrType uint16
|
||||
queue chan []byte
|
||||
stash chan []byte
|
||||
}
|
||||
|
||||
type xdnsConnServer struct {
|
||||
net.PacketConn
|
||||
|
||||
domains []Name
|
||||
domains []domainSpec
|
||||
|
||||
ch chan *record
|
||||
readQueue chan *packet
|
||||
@@ -66,9 +69,9 @@ func NewConnServer(c *Config, raw net.PacketConn) (net.PacketConn, error) {
|
||||
if len(c.Domains) == 0 {
|
||||
return nil, errors.New("empty domains")
|
||||
}
|
||||
domains := make([]Name, 0, len(c.Domains))
|
||||
domains := make([]domainSpec, 0, len(c.Domains))
|
||||
for _, domain := range c.Domains {
|
||||
domain, err := ParseName(domain)
|
||||
domain, err := parseDomainSpec(domain, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -234,6 +237,7 @@ func (c *xdnsConnServer) recvLoop() {
|
||||
func (c *xdnsConnServer) sendLoop() {
|
||||
var nextRec *record
|
||||
for {
|
||||
var err error
|
||||
rec := nextRec
|
||||
nextRec = nil
|
||||
|
||||
@@ -246,18 +250,8 @@ func (c *xdnsConnServer) sendLoop() {
|
||||
}
|
||||
|
||||
if rec.Resp.Rcode() == RcodeNoError && len(rec.Resp.Question) == 1 {
|
||||
rec.Resp.Answer = []RR{
|
||||
{
|
||||
Name: rec.Resp.Question[0].Name,
|
||||
Type: rec.Resp.Question[0].Type,
|
||||
Class: rec.Resp.Question[0].Class,
|
||||
TTL: responseTTL,
|
||||
Data: nil,
|
||||
},
|
||||
}
|
||||
|
||||
var payload bytes.Buffer
|
||||
limit := maxEncodedPayload
|
||||
limit := maxEncodedPayloadForType(rec.Resp.Question[0].Type)
|
||||
timer := time.NewTimer(maxResponseDelay)
|
||||
|
||||
for {
|
||||
@@ -267,6 +261,7 @@ func (c *xdnsConnServer) sendLoop() {
|
||||
c.mutex.Unlock()
|
||||
return
|
||||
}
|
||||
q.rrType = rec.Resp.Question[0].Type
|
||||
c.mutex.Unlock()
|
||||
|
||||
var p []byte
|
||||
@@ -294,7 +289,11 @@ func (c *xdnsConnServer) sendLoop() {
|
||||
}
|
||||
|
||||
limit -= 2 + len(p)
|
||||
if payload.Len() > 0 && limit < 0 {
|
||||
if limit < 0 {
|
||||
if payload.Len() == 0 {
|
||||
errors.LogDebug(context.Background(), rec.Addr, " ", rec.ClientAddr, " xdns payload too large for rrtype ", rec.Resp.Question[0].Type, " ", len(p))
|
||||
continue
|
||||
}
|
||||
c.stash(q, p)
|
||||
break
|
||||
}
|
||||
@@ -308,7 +307,11 @@ func (c *xdnsConnServer) sendLoop() {
|
||||
}
|
||||
|
||||
timer.Stop()
|
||||
rec.Resp.Answer[0].Data = EncodeRDataTXT(payload.Bytes())
|
||||
rec.Resp.Answer, err = answersForPayload(rec.Resp.Question[0], responseTTL, payload.Bytes())
|
||||
if err != nil {
|
||||
errors.LogDebug(context.Background(), rec.Addr, " ", rec.ClientAddr, " xdns encode err ", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
buf, err := rec.Resp.WireFormat()
|
||||
@@ -349,11 +352,6 @@ func (c *xdnsConnServer) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
}
|
||||
|
||||
func (c *xdnsConnServer) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||
if len(p)+2 > maxEncodedPayload {
|
||||
errors.LogDebug(context.Background(), addr, " mask write err short write ", len(p), "+2 > ", maxEncodedPayload)
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
@@ -361,6 +359,14 @@ func (c *xdnsConnServer) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||
if q == nil {
|
||||
return 0, io.ErrClosedPipe
|
||||
}
|
||||
limit := maxEncodedPayloadForType(q.rrType)
|
||||
if q.rrType == 0 {
|
||||
limit = maxEncodedPayloadTXT
|
||||
}
|
||||
if len(p)+2 > limit {
|
||||
errors.LogDebug(context.Background(), addr, " mask write err short write ", len(p), "+2 > ", limit)
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
buf := make([]byte, len(p))
|
||||
copy(buf, p)
|
||||
@@ -406,7 +412,7 @@ func nextPacketServer(r *bytes.Reader) ([]byte, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func responseFor(query *Message, domains []Name) (*Message, []byte) {
|
||||
func responseFor(query *Message, domains []domainSpec) (*Message, []byte) {
|
||||
resp := &Message{
|
||||
ID: query.ID,
|
||||
Flags: 0x8000,
|
||||
@@ -454,11 +460,15 @@ func responseFor(query *Message, domains []Name) (*Message, []byte) {
|
||||
}
|
||||
question := query.Question[0]
|
||||
|
||||
var prefix Name
|
||||
var ok bool
|
||||
var (
|
||||
prefix Name
|
||||
ok bool
|
||||
match domainSpec
|
||||
)
|
||||
for _, domain := range domains {
|
||||
prefix, ok = question.Name.TrimSuffix(domain)
|
||||
prefix, ok = question.Name.TrimSuffix(domain.name)
|
||||
if ok {
|
||||
match = domain
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -473,7 +483,13 @@ func responseFor(query *Message, domains []Name) (*Message, []byte) {
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
if question.Type != RRTypeTXT {
|
||||
switch question.Type {
|
||||
case RRTypeTXT, RRTypeA, RRTypeAAAA:
|
||||
default:
|
||||
resp.Flags |= RcodeNameError
|
||||
return resp, nil
|
||||
}
|
||||
if match.rrType != 0 && question.Type != match.rrType {
|
||||
resp.Flags |= RcodeNameError
|
||||
return resp, nil
|
||||
}
|
||||
@@ -494,78 +510,3 @@ func responseFor(query *Message, domains []Name) (*Message, []byte) {
|
||||
|
||||
return resp, payload
|
||||
}
|
||||
|
||||
func computeMaxEncodedPayload(limit int) int {
|
||||
maxLengthName, err := NewName([][]byte{
|
||||
[]byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
|
||||
[]byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
|
||||
[]byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
|
||||
[]byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
{
|
||||
n := 0
|
||||
for _, label := range maxLengthName {
|
||||
n += len(label) + 1
|
||||
}
|
||||
n += 1
|
||||
if n != 255 {
|
||||
panic("computeMaxEncodedPayload n != 255")
|
||||
}
|
||||
}
|
||||
|
||||
queryLimit := uint16(limit)
|
||||
if int(queryLimit) != limit {
|
||||
queryLimit = 0xffff
|
||||
}
|
||||
query := &Message{
|
||||
Question: []Question{
|
||||
{
|
||||
Name: maxLengthName,
|
||||
Type: RRTypeTXT,
|
||||
Class: RRTypeTXT,
|
||||
},
|
||||
},
|
||||
|
||||
Additional: []RR{
|
||||
{
|
||||
Name: Name{},
|
||||
Type: RRTypeOPT,
|
||||
Class: queryLimit,
|
||||
TTL: 0,
|
||||
Data: []byte{},
|
||||
},
|
||||
},
|
||||
}
|
||||
resp, _ := responseFor(query, []Name{[][]byte{}})
|
||||
|
||||
resp.Answer = []RR{
|
||||
{
|
||||
Name: query.Question[0].Name,
|
||||
Type: query.Question[0].Type,
|
||||
Class: query.Question[0].Class,
|
||||
TTL: responseTTL,
|
||||
Data: nil,
|
||||
},
|
||||
}
|
||||
|
||||
low := 0
|
||||
high := 32768
|
||||
for low+1 < high {
|
||||
mid := (low + high) / 2
|
||||
resp.Answer[0].Data = EncodeRDataTXT(make([]byte, mid))
|
||||
buf, err := resp.WireFormat()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if len(buf) <= limit {
|
||||
low = mid
|
||||
} else {
|
||||
high = mid
|
||||
}
|
||||
}
|
||||
|
||||
return low
|
||||
}
|
||||
|
||||
@@ -0,0 +1,80 @@
|
||||
package xdns
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/xtls/xray-core/common/errors"
|
||||
)
|
||||
|
||||
type domainSpec struct {
|
||||
name Name
|
||||
rrType uint16
|
||||
}
|
||||
|
||||
func rrTypeFromMethod(method string) (uint16, error) {
|
||||
switch strings.ToLower(method) {
|
||||
case "", "txt":
|
||||
return RRTypeTXT, nil
|
||||
case "a":
|
||||
return RRTypeA, nil
|
||||
case "aaaa":
|
||||
return RRTypeAAAA, nil
|
||||
default:
|
||||
return 0, errors.New("unsupported method")
|
||||
}
|
||||
}
|
||||
|
||||
func parseDomainSpec(s string, defaultMethod string) (domainSpec, error) {
|
||||
domainPart := s
|
||||
method := ""
|
||||
hasMethod := false
|
||||
|
||||
if i := strings.LastIndex(s, ":"); i >= 0 {
|
||||
domainPart = s[:i]
|
||||
method = s[i+1:]
|
||||
hasMethod = true
|
||||
} else if defaultMethod != "" {
|
||||
method = defaultMethod
|
||||
hasMethod = true
|
||||
}
|
||||
|
||||
if domainPart == "" {
|
||||
return domainSpec{}, errors.New("empty domain")
|
||||
}
|
||||
|
||||
name, err := ParseName(domainPart)
|
||||
if err != nil {
|
||||
return domainSpec{}, err
|
||||
}
|
||||
|
||||
rrType := uint16(0)
|
||||
if hasMethod {
|
||||
var err error
|
||||
rrType, err = rrTypeFromMethod(method)
|
||||
if err != nil {
|
||||
return domainSpec{}, err
|
||||
}
|
||||
}
|
||||
|
||||
return domainSpec{
|
||||
name: name,
|
||||
rrType: rrType,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func parseResolver(s string) (Name, string, uint16, error) {
|
||||
head, server, ok := strings.Cut(s, "+udp://")
|
||||
if !ok {
|
||||
return nil, "", 0, errors.New("invalid resolver scheme")
|
||||
}
|
||||
if server == "" {
|
||||
return nil, "", 0, errors.New("empty resolver server")
|
||||
}
|
||||
|
||||
spec, err := parseDomainSpec(head, "txt")
|
||||
if err != nil {
|
||||
return nil, "", 0, err
|
||||
}
|
||||
|
||||
return spec.name, server, spec.rrType, nil
|
||||
}
|
||||
Reference in New Issue
Block a user