diff --git a/transport/internet/finalmask/xdns/client.go b/transport/internet/finalmask/xdns/client.go index 7ceb7c50..7513a0a9 100644 --- a/transport/internet/finalmask/xdns/client.go +++ b/transport/internet/finalmask/xdns/client.go @@ -10,7 +10,6 @@ import ( "io" "net" "strconv" - "strings" "sync" "sync/atomic" "time" @@ -40,6 +39,7 @@ type xdnsConnClient struct { net.PacketConn resolverAddrs []*net.UDPAddr + resolverTypes []uint16 resolverIdx uint32 resolverSend map[string]*atomic.Uint32 @@ -61,17 +61,15 @@ func NewConnClient(c *Config, raw net.PacketConn) (net.PacketConn, error) { var domains []Name var servers []string + var resolverTypes []uint16 for _, rs := range c.Resolvers { - parts := strings.Split(rs, "+udp://") - if len(parts) != 2 { - return nil, errors.New("invalid resolvers") - } - domain, err := ParseName(parts[0]) + domain, server, resolverType, err := parseResolver(rs) if err != nil { - return nil, err + return nil, errors.New("invalid resolvers").Base(err) } domains = append(domains, domain) - servers = append(servers, parts[1]) + servers = append(servers, server) + resolverTypes = append(resolverTypes, resolverType) } var resolverAddrs []*net.UDPAddr @@ -98,6 +96,7 @@ func NewConnClient(c *Config, raw net.PacketConn) (net.PacketConn, error) { PacketConn: raw, resolverAddrs: resolverAddrs, + resolverTypes: resolverTypes, resolverIdx: 0, resolverSend: resolverSend, @@ -216,7 +215,7 @@ func (c *xdnsConnClient) sendLoop() { default: } } else { - encoded, _ := encode(nil, c.clientID, c.domains[c.resolverIdx]) + encoded, _ := encode(nil, c.clientID, c.domains[c.resolverIdx], c.resolverTypes[c.resolverIdx]) p = &packet{ p: encoded, } @@ -276,7 +275,8 @@ func (c *xdnsConnClient) WriteTo(p []byte, addr net.Addr) (n int, err error) { return 0, io.ErrClosedPipe } - encoded, err := encode(p, c.clientID, c.domains[c.resolverIdx%uint32(len(c.resolverAddrs))]) + idx := c.resolverIdx % uint32(len(c.resolverAddrs)) + encoded, err := encode(p, c.clientID, c.domains[idx], c.resolverTypes[idx]) if err != nil { errors.LogDebug(context.Background(), addr, " xdns wireformat err ", err, " ", len(p)) return 0, nil @@ -299,7 +299,7 @@ func (c *xdnsConnClient) Close() error { return c.PacketConn.Close() } -func encode(p []byte, clientID []byte, domain Name) ([]byte, error) { +func encode(p []byte, clientID []byte, domain Name, qtype uint16) ([]byte, error) { var decoded []byte { if len(p) >= 224 { @@ -338,7 +338,7 @@ func encode(p []byte, clientID []byte, domain Name) ([]byte, error) { Question: []Question{ { Name: name, - Type: RRTypeTXT, + Type: qtype, Class: ClassIN, }, }, @@ -396,29 +396,22 @@ func dnsResponsePayload(resp *Message, domains []Name) []byte { return nil } - if len(resp.Answer) != 1 { + if len(resp.Answer) == 0 { return nil } - answer := resp.Answer[0] - var ok bool - for _, domain := range domains { - _, ok = answer.Name.TrimSuffix(domain) - if ok { - break + for _, answer := range resp.Answer { + var ok bool + for _, domain := range domains { + _, ok = answer.Name.TrimSuffix(domain) + if ok { + break + } + } + if !ok { + return nil } } - if !ok { - return nil - } - if answer.Type != RRTypeTXT { - return nil - } - payload, err := DecodeRDataTXT(answer.Data) - if err != nil { - return nil - } - - return payload + return decodeResponsePayload(resp.Answer) } diff --git a/transport/internet/finalmask/xdns/dns.go b/transport/internet/finalmask/xdns/dns.go index 4cdac7cd..77490385 100644 --- a/transport/internet/finalmask/xdns/dns.go +++ b/transport/internet/finalmask/xdns/dns.go @@ -45,8 +45,14 @@ var ( ) const ( + // https://tools.ietf.org/html/rfc1035#section-3.2.2 + RRTypeA = 1 + // https://tools.ietf.org/html/rfc1035#section-3.2.2 + RRTypeCNAME = 5 // https://tools.ietf.org/html/rfc1035#section-3.2.2 RRTypeTXT = 16 + // https://tools.ietf.org/html/rfc3596#section-2.1 + RRTypeAAAA = 28 // https://tools.ietf.org/html/rfc6891#section-6.1.1 RRTypeOPT = 41 diff --git a/transport/internet/finalmask/xdns/dns_test.go b/transport/internet/finalmask/xdns/dns_test.go index 2ddc9da5..45da317d 100644 --- a/transport/internet/finalmask/xdns/dns_test.go +++ b/transport/internet/finalmask/xdns/dns_test.go @@ -593,3 +593,113 @@ func TestRDataTXTRoundTrip(t *testing.T) { } } } + +func TestIPAnswerPayloadRoundTrip(t *testing.T) { + for _, rrType := range []uint16{RRTypeA, RRTypeAAAA} { + for _, payload := range [][]byte{ + {}, + {0x01}, + []byte("hello world"), + bytes.Repeat([]byte{0xab}, payloadChunkSizeForType(rrType)*3+1), + } { + question := Question{ + Name: mustParseName("example.com"), + Type: rrType, + Class: ClassIN, + } + answers, err := answersForPayload(question, responseTTL, payload) + if err != nil { + t.Fatalf("answersForPayload(%d) err = %v", rrType, err) + } + + if len(answers) > 1 { + answers[0], answers[len(answers)-1] = answers[len(answers)-1], answers[0] + } + + decoded := decodeResponsePayload(answers) + if !bytes.Equal(decoded, payload) { + t.Fatalf("rrType=%d decoded %x want %x", rrType, decoded, payload) + } + } + } +} + +func TestParseResolver(t *testing.T) { + tests := []struct { + resolver string + rrType uint16 + }{ + {"example.com+udp://1.1.1.1:53", RRTypeTXT}, + {"example.com:txt+udp://1.1.1.1:53", RRTypeTXT}, + {"example.com:a+udp://1.1.1.1:53", RRTypeA}, + {"example.com:aaaa+udp://1.1.1.1:53", RRTypeAAAA}, + } + + for _, test := range tests { + domain, server, rrType, err := parseResolver(test.resolver) + if err != nil { + t.Fatalf("parseResolver(%q) err = %v", test.resolver, err) + } + if domain.String() != "example.com" || server != "1.1.1.1:53" || rrType != test.rrType { + t.Fatalf("parseResolver(%q) = (%q, %q, %d)", test.resolver, domain.String(), server, rrType) + } + } +} + +func TestParseDomainSpec(t *testing.T) { + tests := []struct { + spec string + def string + rrType uint16 + wantErr bool + }{ + {"example.com", "", 0, false}, + {"example.com", "txt", RRTypeTXT, false}, + {"example.com:a", "", RRTypeA, false}, + {"example.com:aaaa", "", RRTypeAAAA, false}, + {"example.com:doh", "", 0, true}, + } + + for _, test := range tests { + got, err := parseDomainSpec(test.spec, test.def) + if test.wantErr { + if err == nil { + t.Fatalf("parseDomainSpec(%q, %q) err = nil", test.spec, test.def) + } + continue + } + if err != nil { + t.Fatalf("parseDomainSpec(%q, %q) err = %v", test.spec, test.def, err) + } + if got.name.String() != "example.com" || got.rrType != test.rrType { + t.Fatalf("parseDomainSpec(%q, %q) = (%q, %d)", test.spec, test.def, got.name.String(), got.rrType) + } + } +} + +func TestResponseForMethodRestriction(t *testing.T) { + query := &Message{ + ID: 1, + Flags: 0x0100, + Question: []Question{{ + Name: mustParseName("abc.example.com"), + Type: RRTypeTXT, + Class: ClassIN, + }}, + Additional: []RR{{ + Name: Name{}, + Type: RRTypeOPT, + Class: 4096, + }}, + } + + resp, _ := responseFor(query, []domainSpec{{name: mustParseName("example.com"), rrType: RRTypeA}}) + if resp == nil || resp.Rcode() != RcodeNameError { + t.Fatalf("responseFor method restriction rcode = %v", resp) + } + + resp, _ = responseFor(query, []domainSpec{{name: mustParseName("example.com")}}) + if resp == nil || resp.Rcode() != RcodeNoError { + t.Fatalf("responseFor unrestricted rcode = %v", resp) + } +} diff --git a/transport/internet/finalmask/xdns/record_transport.go b/transport/internet/finalmask/xdns/record_transport.go new file mode 100644 index 00000000..8428baa4 --- /dev/null +++ b/transport/internet/finalmask/xdns/record_transport.go @@ -0,0 +1,226 @@ +package xdns + +import "bytes" + +const ipRecordHeaderSize = 2 + +func maxEncodedPayloadForType(rrType uint16) int { + switch rrType { + case RRTypeA: + return maxEncodedPayloadA + case RRTypeAAAA: + return maxEncodedPayloadAAAA + default: + return maxEncodedPayloadTXT + } +} + +func rrDataSizeForType(rrType uint16) int { + switch rrType { + case RRTypeA: + return 4 + case RRTypeAAAA: + return 16 + default: + return 0 + } +} + +func payloadChunkSizeForType(rrType uint16) int { + size := rrDataSizeForType(rrType) + if size <= ipRecordHeaderSize { + return 0 + } + return size - ipRecordHeaderSize +} + +func answersForPayload(question Question, ttl uint32, payload []byte) ([]RR, error) { + switch question.Type { + case RRTypeTXT: + return []RR{ + { + Name: question.Name, + Type: question.Type, + Class: question.Class, + TTL: ttl, + Data: EncodeRDataTXT(payload), + }, + }, nil + case RRTypeA, RRTypeAAAA: + return ipAnswersForPayload(question, ttl, payload) + default: + return nil, ErrIntegerOverflow + } +} + +func ipAnswersForPayload(question Question, ttl uint32, payload []byte) ([]RR, error) { + chunkSize := payloadChunkSizeForType(question.Type) + rrDataSize := rrDataSizeForType(question.Type) + if chunkSize == 0 || rrDataSize == 0 { + return nil, ErrIntegerOverflow + } + + numRecords := 1 + if len(payload) > 0 { + numRecords = (len(payload) + chunkSize - 1) / chunkSize + } + if numRecords > 256 { + return nil, ErrIntegerOverflow + } + + answers := make([]RR, 0, numRecords) + for i := 0; i < numRecords; i++ { + offset := i * chunkSize + n := len(payload) - offset + if n < 0 { + n = 0 + } + if n > chunkSize { + n = chunkSize + } + + data := make([]byte, rrDataSize) + data[0] = byte(i) + data[1] = byte(n) + copy(data[ipRecordHeaderSize:], payload[offset:offset+n]) + + answers = append(answers, RR{ + Name: question.Name, + Type: question.Type, + Class: question.Class, + TTL: ttl, + Data: data, + }) + } + + return answers, nil +} + +func decodeResponsePayload(answers []RR) []byte { + if len(answers) == 0 { + return nil + } + + switch answers[0].Type { + case RRTypeTXT: + if len(answers) != 1 { + return nil + } + payload, err := DecodeRDataTXT(answers[0].Data) + if err != nil { + return nil + } + return payload + case RRTypeA, RRTypeAAAA: + return decodeIPAnswerPayload(answers, answers[0].Type) + default: + return nil + } +} + +func decodeIPAnswerPayload(answers []RR, rrType uint16) []byte { + chunkSize := payloadChunkSizeForType(rrType) + rrDataSize := rrDataSizeForType(rrType) + if chunkSize == 0 || rrDataSize == 0 || len(answers) > 256 { + return nil + } + + parts := make([][]byte, len(answers)) + for _, answer := range answers { + if answer.Type != rrType || len(answer.Data) != rrDataSize { + return nil + } + idx := int(answer.Data[0]) + n := int(answer.Data[1]) + if idx >= len(answers) || n > chunkSize || parts[idx] != nil { + return nil + } + + part := make([]byte, n) + copy(part, answer.Data[ipRecordHeaderSize:ipRecordHeaderSize+n]) + parts[idx] = part + } + + var payload bytes.Buffer + for _, part := range parts { + if part == nil { + return nil + } + payload.Write(part) + } + return payload.Bytes() +} + +func computeMaxEncodedPayload(limit int) int { + return computeMaxEncodedPayloadForType(limit, RRTypeTXT) +} + +func computeMaxEncodedPayloadForType(limit int, rrType uint16) int { + maxLengthName, err := NewName([][]byte{ + []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"), + []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"), + []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"), + []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"), + }) + if err != nil { + panic(err) + } + { + n := 0 + for _, label := range maxLengthName { + n += len(label) + 1 + } + n += 1 + if n != 255 { + panic("computeMaxEncodedPayload n != 255") + } + } + + queryLimit := uint16(limit) + if int(queryLimit) != limit { + queryLimit = 0xffff + } + query := &Message{ + Question: []Question{ + { + Name: maxLengthName, + Type: rrType, + Class: ClassIN, + }, + }, + Additional: []RR{ + { + Name: Name{}, + Type: RRTypeOPT, + Class: queryLimit, + TTL: 0, + Data: []byte{}, + }, + }, + } + resp, _ := responseFor(query, []domainSpec{{name: Name{[]byte{}}}}) + + low := 0 + high := 32768 + if chunkSize := payloadChunkSizeForType(rrType); chunkSize > 0 { + high = 256*chunkSize + 1 + } + for low+1 < high { + mid := (low + high) / 2 + resp.Answer, err = answersForPayload(query.Question[0], responseTTL, make([]byte, mid)) + if err != nil { + panic(err) + } + buf, err := resp.WireFormat() + if err != nil { + panic(err) + } + if len(buf) <= limit { + low = mid + } else { + high = mid + } + } + + return low +} diff --git a/transport/internet/finalmask/xdns/server.go b/transport/internet/finalmask/xdns/server.go index c96149ad..654f7fdb 100644 --- a/transport/internet/finalmask/xdns/server.go +++ b/transport/internet/finalmask/xdns/server.go @@ -21,8 +21,10 @@ const ( ) var ( - maxUDPPayload = 1280 - 40 - 8 - maxEncodedPayload = computeMaxEncodedPayload(maxUDPPayload) + maxUDPPayload = 1280 - 40 - 8 + maxEncodedPayloadTXT = computeMaxEncodedPayloadForType(maxUDPPayload, RRTypeTXT) + maxEncodedPayloadA = computeMaxEncodedPayloadForType(maxUDPPayload, RRTypeA) + maxEncodedPayloadAAAA = computeMaxEncodedPayloadForType(maxUDPPayload, RRTypeAAAA) ) func clientIDToAddr(clientID [8]byte) *net.UDPAddr { @@ -44,15 +46,16 @@ type record struct { } type queue struct { - last time.Time - queue chan []byte - stash chan []byte + last time.Time + rrType uint16 + queue chan []byte + stash chan []byte } type xdnsConnServer struct { net.PacketConn - domains []Name + domains []domainSpec ch chan *record readQueue chan *packet @@ -66,9 +69,9 @@ func NewConnServer(c *Config, raw net.PacketConn) (net.PacketConn, error) { if len(c.Domains) == 0 { return nil, errors.New("empty domains") } - domains := make([]Name, 0, len(c.Domains)) + domains := make([]domainSpec, 0, len(c.Domains)) for _, domain := range c.Domains { - domain, err := ParseName(domain) + domain, err := parseDomainSpec(domain, "") if err != nil { return nil, err } @@ -234,6 +237,7 @@ func (c *xdnsConnServer) recvLoop() { func (c *xdnsConnServer) sendLoop() { var nextRec *record for { + var err error rec := nextRec nextRec = nil @@ -246,18 +250,8 @@ func (c *xdnsConnServer) sendLoop() { } if rec.Resp.Rcode() == RcodeNoError && len(rec.Resp.Question) == 1 { - rec.Resp.Answer = []RR{ - { - Name: rec.Resp.Question[0].Name, - Type: rec.Resp.Question[0].Type, - Class: rec.Resp.Question[0].Class, - TTL: responseTTL, - Data: nil, - }, - } - var payload bytes.Buffer - limit := maxEncodedPayload + limit := maxEncodedPayloadForType(rec.Resp.Question[0].Type) timer := time.NewTimer(maxResponseDelay) for { @@ -267,6 +261,7 @@ func (c *xdnsConnServer) sendLoop() { c.mutex.Unlock() return } + q.rrType = rec.Resp.Question[0].Type c.mutex.Unlock() var p []byte @@ -294,7 +289,11 @@ func (c *xdnsConnServer) sendLoop() { } limit -= 2 + len(p) - if payload.Len() > 0 && limit < 0 { + if limit < 0 { + if payload.Len() == 0 { + errors.LogDebug(context.Background(), rec.Addr, " ", rec.ClientAddr, " xdns payload too large for rrtype ", rec.Resp.Question[0].Type, " ", len(p)) + continue + } c.stash(q, p) break } @@ -308,7 +307,11 @@ func (c *xdnsConnServer) sendLoop() { } timer.Stop() - rec.Resp.Answer[0].Data = EncodeRDataTXT(payload.Bytes()) + rec.Resp.Answer, err = answersForPayload(rec.Resp.Question[0], responseTTL, payload.Bytes()) + if err != nil { + errors.LogDebug(context.Background(), rec.Addr, " ", rec.ClientAddr, " xdns encode err ", err) + continue + } } buf, err := rec.Resp.WireFormat() @@ -349,11 +352,6 @@ func (c *xdnsConnServer) ReadFrom(p []byte) (n int, addr net.Addr, err error) { } func (c *xdnsConnServer) WriteTo(p []byte, addr net.Addr) (n int, err error) { - if len(p)+2 > maxEncodedPayload { - errors.LogDebug(context.Background(), addr, " mask write err short write ", len(p), "+2 > ", maxEncodedPayload) - return 0, nil - } - c.mutex.Lock() defer c.mutex.Unlock() @@ -361,6 +359,14 @@ func (c *xdnsConnServer) WriteTo(p []byte, addr net.Addr) (n int, err error) { if q == nil { return 0, io.ErrClosedPipe } + limit := maxEncodedPayloadForType(q.rrType) + if q.rrType == 0 { + limit = maxEncodedPayloadTXT + } + if len(p)+2 > limit { + errors.LogDebug(context.Background(), addr, " mask write err short write ", len(p), "+2 > ", limit) + return 0, nil + } buf := make([]byte, len(p)) copy(buf, p) @@ -406,7 +412,7 @@ func nextPacketServer(r *bytes.Reader) ([]byte, error) { } } -func responseFor(query *Message, domains []Name) (*Message, []byte) { +func responseFor(query *Message, domains []domainSpec) (*Message, []byte) { resp := &Message{ ID: query.ID, Flags: 0x8000, @@ -454,11 +460,15 @@ func responseFor(query *Message, domains []Name) (*Message, []byte) { } question := query.Question[0] - var prefix Name - var ok bool + var ( + prefix Name + ok bool + match domainSpec + ) for _, domain := range domains { - prefix, ok = question.Name.TrimSuffix(domain) + prefix, ok = question.Name.TrimSuffix(domain.name) if ok { + match = domain break } } @@ -473,7 +483,13 @@ func responseFor(query *Message, domains []Name) (*Message, []byte) { return resp, nil } - if question.Type != RRTypeTXT { + switch question.Type { + case RRTypeTXT, RRTypeA, RRTypeAAAA: + default: + resp.Flags |= RcodeNameError + return resp, nil + } + if match.rrType != 0 && question.Type != match.rrType { resp.Flags |= RcodeNameError return resp, nil } @@ -494,78 +510,3 @@ func responseFor(query *Message, domains []Name) (*Message, []byte) { return resp, payload } - -func computeMaxEncodedPayload(limit int) int { - maxLengthName, err := NewName([][]byte{ - []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"), - []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"), - []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"), - []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"), - }) - if err != nil { - panic(err) - } - { - n := 0 - for _, label := range maxLengthName { - n += len(label) + 1 - } - n += 1 - if n != 255 { - panic("computeMaxEncodedPayload n != 255") - } - } - - queryLimit := uint16(limit) - if int(queryLimit) != limit { - queryLimit = 0xffff - } - query := &Message{ - Question: []Question{ - { - Name: maxLengthName, - Type: RRTypeTXT, - Class: RRTypeTXT, - }, - }, - - Additional: []RR{ - { - Name: Name{}, - Type: RRTypeOPT, - Class: queryLimit, - TTL: 0, - Data: []byte{}, - }, - }, - } - resp, _ := responseFor(query, []Name{[][]byte{}}) - - resp.Answer = []RR{ - { - Name: query.Question[0].Name, - Type: query.Question[0].Type, - Class: query.Question[0].Class, - TTL: responseTTL, - Data: nil, - }, - } - - low := 0 - high := 32768 - for low+1 < high { - mid := (low + high) / 2 - resp.Answer[0].Data = EncodeRDataTXT(make([]byte, mid)) - buf, err := resp.WireFormat() - if err != nil { - panic(err) - } - if len(buf) <= limit { - low = mid - } else { - high = mid - } - } - - return low -} diff --git a/transport/internet/finalmask/xdns/spec.go b/transport/internet/finalmask/xdns/spec.go new file mode 100644 index 00000000..28461569 --- /dev/null +++ b/transport/internet/finalmask/xdns/spec.go @@ -0,0 +1,80 @@ +package xdns + +import ( + "strings" + + "github.com/xtls/xray-core/common/errors" +) + +type domainSpec struct { + name Name + rrType uint16 +} + +func rrTypeFromMethod(method string) (uint16, error) { + switch strings.ToLower(method) { + case "", "txt": + return RRTypeTXT, nil + case "a": + return RRTypeA, nil + case "aaaa": + return RRTypeAAAA, nil + default: + return 0, errors.New("unsupported method") + } +} + +func parseDomainSpec(s string, defaultMethod string) (domainSpec, error) { + domainPart := s + method := "" + hasMethod := false + + if i := strings.LastIndex(s, ":"); i >= 0 { + domainPart = s[:i] + method = s[i+1:] + hasMethod = true + } else if defaultMethod != "" { + method = defaultMethod + hasMethod = true + } + + if domainPart == "" { + return domainSpec{}, errors.New("empty domain") + } + + name, err := ParseName(domainPart) + if err != nil { + return domainSpec{}, err + } + + rrType := uint16(0) + if hasMethod { + var err error + rrType, err = rrTypeFromMethod(method) + if err != nil { + return domainSpec{}, err + } + } + + return domainSpec{ + name: name, + rrType: rrType, + }, nil +} + +func parseResolver(s string) (Name, string, uint16, error) { + head, server, ok := strings.Cut(s, "+udp://") + if !ok { + return nil, "", 0, errors.New("invalid resolver scheme") + } + if server == "" { + return nil, "", 0, errors.New("empty resolver server") + } + + spec, err := parseDomainSpec(head, "txt") + if err != nil { + return nil, "", 0, err + } + + return spec.name, server, spec.rrType, nil +}