Files

363 lines
6.4 KiB
Go

package xdns
import (
"bytes"
"context"
"crypto/rand"
"encoding/base32"
"encoding/binary"
go_errors "errors"
"io"
"net"
"sync"
"time"
"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/transport/internet/finalmask"
)
// Random padding added to every upstream query by encode: numPadding bytes
// for a query that carries data, numPaddingForPoll bytes for an empty poll.
const (
	numPadding        = 3
	numPaddingForPoll = 8
)

// Poll pacing used by sendLoop. The idle delay starts at initPollDelay, is
// multiplied by pollDelayMultiplier each time the poll timer fires, and is
// capped at maxPollDelay. pollLimit is the buffer size of pollChan.
const (
	initPollDelay       = 500 * time.Millisecond
	maxPollDelay        = 10 * time.Second
	pollDelayMultiplier = 2.0
	pollLimit           = 16
)

// base32Encoding is standard (RFC 4648) base32 without '=' padding; encode
// uses it to turn the raw query payload into DNS label text.
var base32Encoding = base32.StdEncoding.WithPadding(base32.NoPadding)
// packet is one datagram payload together with the peer address it came
// from or is destined to. It is what travels through readQueue and writeQueue.
type packet struct {
	p    []byte
	addr net.Addr
}

// xdnsConnClient is the client side of the xdns transport: a net.PacketConn
// whose written packets are wrapped in DNS TXT queries (see encode) and whose
// read packets are extracted from TXT answers (see recvLoop).
type xdnsConnClient struct {
	// The wrapped transport that actually carries the DNS messages.
	net.PacketConn

	clientID []byte // random per-connection ID placed at the front of every query
	domain   Name   // tunnel domain appended to every query name

	pollChan   chan struct{} // recvLoop -> sendLoop: send the next query promptly
	readQueue  chan *packet  // decoded packets waiting to be returned by ReadFrom
	writeQueue chan *packet  // encoded queries waiting to be sent by sendLoop

	// closed is set once the connection shuts down. mutex serializes
	// WriteTo's closed-check-and-enqueue against recvLoop's shutdown path,
	// which sets closed and closes writeQueue.
	closed bool
	mutex  sync.Mutex
}
// NewConnClient wraps raw in a client-side xdns PacketConn that tunnels
// packets through DNS TXT queries under the domain configured in c.
//
// A random 8-byte client ID is generated for the connection, and two
// background goroutines are started on it: recvLoop, which reads and decodes
// responses from raw, and sendLoop, which writes queued queries and idle polls
// to raw. It returns an error if c.Domain cannot be parsed as a domain name.
func NewConnClient(c *Config, raw net.PacketConn) (net.PacketConn, error) {
	domain, err := ParseName(c.Domain)
	if err != nil {
		return nil, err
	}

	id := make([]byte, 8)
	common.Must2(rand.Read(id))

	client := &xdnsConnClient{
		PacketConn: raw,
		clientID:   id,
		domain:     domain,
		pollChan:   make(chan struct{}, pollLimit),
		readQueue:  make(chan *packet, 256),
		writeQueue: make(chan *packet, 256),
	}

	go client.recvLoop()
	go client.sendLoop()
	return client, nil
}
// recvLoop is the sole reader of the underlying PacketConn.
//
// Each datagram is parsed as a DNS message. The TXT payload of an acceptable
// response (see dnsResponsePayload) is split into length-prefixed packets
// (see nextPacket), and each packet is queued on readQueue for ReadFrom; a
// packet is dropped (and logged) if the queue is full. When a response carried
// at least one packet, sendLoop is nudged through pollChan so it sends its next
// query without waiting for the poll timer.
//
// Read errors other than net.ErrClosed / io.EOF are ignored and reading
// continues. On exit (connection closed, EOF, or the closed flag set),
// recvLoop closes pollChan and readQueue. Then, while holding mutex, it marks
// the client closed and closes writeQueue, so WriteTo can never send on a
// closed channel.
func (c *xdnsConnClient) recvLoop() {
	var buf [finalmask.UDPSize]byte
	for {
		// closed is written concurrently by Close and sendLoop; read it under
		// the mutex to avoid a data race.
		c.mutex.Lock()
		closed := c.closed
		c.mutex.Unlock()
		if closed {
			break
		}

		n, addr, err := c.PacketConn.ReadFrom(buf[:])
		if err != nil || n == 0 {
			if go_errors.Is(err, net.ErrClosed) || go_errors.Is(err, io.EOF) {
				break
			}
			continue
		}

		resp, err := MessageFromWireFormat(buf[:n])
		if err != nil {
			errors.LogDebug(context.Background(), addr, " xdns from wireformat err ", err)
			continue
		}

		r := bytes.NewReader(dnsResponsePayload(&resp, c.domain))
		anyPacket := false
		for {
			p, err := nextPacket(r)
			if err != nil {
				break
			}
			anyPacket = true
			// nextPacket allocates a fresh slice for every packet (it does not
			// alias buf or the payload), so it can be queued without a copy.
			select {
			case c.readQueue <- &packet{p: p, addr: addr}:
			default:
				errors.LogDebug(context.Background(), addr, " mask read err queue full")
			}
		}
		if anyPacket {
			// Non-blocking: one pending nudge is enough.
			select {
			case c.pollChan <- struct{}{}:
			default:
			}
		}
	}

	errors.LogDebug(context.Background(), "xdns closed")
	close(c.pollChan)
	close(c.readQueue)
	c.mutex.Lock()
	defer c.mutex.Unlock()
	c.closed = true
	close(c.writeQueue)
}
// sendLoop is the sole writer to the underlying PacketConn.
//
// It forwards encoded queries from writeQueue, giving them priority over poll
// requests and the poll timer. When there is no data, a nudge from recvLoop
// (pollChan) or the poll timer expiring causes an empty poll query. The poll
// is sent to the address of the most recent data packet, and no poll is sent
// before any data has been written.
//
// Pacing: each timer expiry multiplies the poll delay by pollDelayMultiplier,
// capped at maxPollDelay. Any other wake-up (data or a nudge) resets the delay
// to initPollDelay.
//
// The loop returns when the client is marked closed, when recvLoop closes
// writeQueue/pollChan during shutdown, or when a write reports that the
// underlying connection is closed. In the last case it also marks the client
// closed.
func (c *xdnsConnClient) sendLoop() {
	var addr net.Addr // peer of the most recent data packet; polls go here
	pollDelay := initPollDelay
	pollTimer := time.NewTimer(pollDelay)
	defer pollTimer.Stop()

	for {
		var p *packet
		open := true
		pollTimerExpired := false
		// Queued data first; otherwise wait for data, a nudge, or the timer.
		select {
		case p, open = <-c.writeQueue:
		default:
			select {
			case p, open = <-c.writeQueue:
			case _, open = <-c.pollChan:
			case <-pollTimer.C:
				pollTimerExpired = true
			}
		}
		if !open {
			// recvLoop closes pollChan and writeQueue on shutdown. Stop here
			// instead of spinning on the closed channels and sending
			// spurious polls.
			return
		}

		if p != nil {
			addr = p.addr
			// A data query is about to go out; drop one pending nudge so it
			// does not also trigger an immediate empty poll.
			select {
			case <-c.pollChan:
			default:
			}
		} else if addr != nil {
			// Build an empty poll query. If encoding fails, skip it rather
			// than writing an empty datagram.
			if encoded, err := encode(nil, c.clientID, c.domain); err == nil {
				p = &packet{p: encoded, addr: addr}
			}
		}

		if pollTimerExpired {
			pollDelay = time.Duration(float64(pollDelay) * pollDelayMultiplier)
			if pollDelay > maxPollDelay {
				pollDelay = maxPollDelay
			}
		} else {
			// Stop the timer (draining a value that already fired) before
			// re-arming it at the initial delay.
			if !pollTimer.Stop() {
				<-pollTimer.C
			}
			pollDelay = initPollDelay
		}
		pollTimer.Reset(pollDelay)

		// closed is written concurrently by Close and recvLoop; read it under
		// the mutex to avoid a data race.
		c.mutex.Lock()
		closed := c.closed
		c.mutex.Unlock()
		if closed {
			return
		}

		if p != nil {
			_, err := c.PacketConn.WriteTo(p.p, p.addr)
			if go_errors.Is(err, net.ErrClosed) || go_errors.Is(err, io.ErrClosedPipe) {
				c.mutex.Lock()
				c.closed = true
				c.mutex.Unlock()
				return
			}
		}
	}
}
// ReadFrom blocks until recvLoop has decoded a packet, copies it into p and
// returns its length and source address.
//
// If p is too small for the packet, the packet is dropped (logged at debug
// level), and ReadFrom returns 0 bytes, the packet's address and a nil error.
// Once recvLoop has shut down and the queue is drained, it returns
// net.ErrClosed.
func (c *xdnsConnClient) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
	pkt, ok := <-c.readQueue
	if !ok {
		return 0, nil, net.ErrClosed
	}
	if len(p) >= len(pkt.p) {
		return copy(p, pkt.p), pkt.addr, nil
	}
	errors.LogDebug(context.Background(), pkt.addr, " mask read err short buffer ", len(p), " ", len(pkt.p))
	return 0, pkt.addr, nil
}
// WriteTo encodes p as a DNS TXT query and queues it for sendLoop to send to
// addr. It never blocks on the queue.
//
// It returns len(p) on success. If p cannot be encoded (for example, it is
// 224 bytes or longer) or the write queue is full, the packet is dropped,
// logged at debug level, and 0, nil is returned. After the client is closed
// it returns io.ErrClosedPipe.
func (c *xdnsConnClient) WriteTo(p []byte, addr net.Addr) (n int, err error) {
	// Holding the mutex across the closed check and the enqueue means we can
	// never send on writeQueue after recvLoop's shutdown has closed it.
	c.mutex.Lock()
	defer c.mutex.Unlock()

	if c.closed {
		return 0, io.ErrClosedPipe
	}

	query, encErr := encode(p, c.clientID, c.domain)
	if encErr != nil {
		errors.LogDebug(context.Background(), addr, " xdns wireformat err ", encErr, " ", len(p))
		return 0, nil
	}

	select {
	case c.writeQueue <- &packet{p: query, addr: addr}:
		return len(p), nil
	default:
	}
	errors.LogDebug(context.Background(), addr, " mask write err queue full")
	return 0, nil
}
// Close marks the client closed, so subsequent WriteTo calls fail with
// io.ErrClosedPipe, and then closes the underlying PacketConn.
//
// Closing the underlying connection is what unblocks recvLoop's pending
// ReadFrom. For standard net conns that read fails with net.ErrClosed, after
// which recvLoop shuts down the queues and sendLoop exits.
func (c *xdnsConnClient) Close() error {
	// WriteTo reads closed under the mutex and the background loops touch it
	// concurrently, so an unsynchronized write here would be a data race.
	c.mutex.Lock()
	c.closed = true
	c.mutex.Unlock()
	return c.PacketConn.Close()
}
// encode builds the wire-format DNS query that carries p upstream for the
// client identified by clientID.
//
// Query name layout: the raw bytes
//
//	clientID || byte(224+pad) || pad random bytes || [byte(len(p)) || p]
//
// are base32-encoded without padding, lower-cased, split into labels of at
// most 63 bytes, and followed by domain. pad is numPadding for a data query
// and numPaddingForPoll when p is empty (a poll). The query asks for a TXT
// record (class IN), has a random ID and flags 0x0100 (RD). It also carries
// an EDNS(0) OPT record advertising a 4096-byte UDP payload size.
//
// It returns an error if p is 224 bytes or longer (length bytes from 224 up
// are used for the padding marker), or if the name or message cannot be
// built, e.g. because the resulting name is too long.
func encode(p []byte, clientID []byte, domain Name) ([]byte, error) {
	if len(p) >= 224 {
		return nil, errors.New("too long")
	}

	padLen := numPaddingForPoll
	if len(p) > 0 {
		padLen = numPadding
	}

	raw := make([]byte, 0, len(clientID)+1+padLen+1+len(p))
	raw = append(raw, clientID...)
	raw = append(raw, byte(224+padLen))
	padding := make([]byte, padLen)
	// crypto/rand does not fail in practice; the error is deliberately ignored.
	_, _ = rand.Read(padding)
	raw = append(raw, padding...)
	if len(p) > 0 {
		raw = append(raw, byte(len(p)))
		raw = append(raw, p...)
	}

	text := make([]byte, base32Encoding.EncodedLen(len(raw)))
	base32Encoding.Encode(text, raw)
	labels := append(chunks(bytes.ToLower(text), 63), domain...)
	name, err := NewName(labels)
	if err != nil {
		return nil, err
	}

	var idBytes [2]byte
	_, _ = rand.Read(idBytes[:])

	query := &Message{
		ID:    binary.BigEndian.Uint16(idBytes[:]),
		Flags: 0x0100,
		Question: []Question{{
			Name:  name,
			Type:  RRTypeTXT,
			Class: ClassIN,
		}},
		Additional: []RR{{
			Name:  Name{},
			Type:  RRTypeOPT,
			Class: 4096,
			TTL:   0,
			Data:  []byte{},
		}},
	}

	wire, err := query.WireFormat()
	if err != nil {
		return nil, err
	}
	return wire, nil
}
// chunks splits p into consecutive sub-slices of at most n bytes each; only
// the last one may be shorter. The sub-slices alias p. A nil or empty p yields
// a nil result. n must be positive.
func chunks(p []byte, n int) [][]byte {
	var out [][]byte
	for start := 0; start < len(p); start += n {
		end := start + n
		if end > len(p) {
			end = len(p)
		}
		out = append(out, p[start:end])
	}
	return out
}
// nextPacket reads one packet from r. A packet is encoded as a 2-byte
// big-endian length followed by that many bytes.
//
// It returns io.EOF only when r was already exhausted before the length
// prefix, and io.ErrUnexpectedEOF when r ends inside the prefix or the body.
// On success the returned slice is freshly allocated and does not alias r.
func nextPacket(r *bytes.Reader) ([]byte, error) {
	var prefix [2]byte
	if _, err := io.ReadFull(r, prefix[:]); err != nil {
		return nil, err
	}

	body := make([]byte, binary.BigEndian.Uint16(prefix[:]))
	_, err := io.ReadFull(r, body)
	if err == io.EOF {
		// The prefix was present but the body is missing: that is a truncated
		// packet, not a clean end of input.
		return body, io.ErrUnexpectedEOF
	}
	return body, err
}
// dnsResponsePayload extracts the tunnel payload carried by resp, or returns
// nil if resp is not an acceptable reply.
//
// A reply is accepted only if all of the following hold:
//   - the QR bit is set (it is a response) and the RCODE is NoError;
//   - it has exactly one answer;
//   - the answer's name ends in domain;
//   - the answer is a TXT record whose data decodes successfully.
func dnsResponsePayload(resp *Message, domain Name) []byte {
	isResponse := resp.Flags&0x8000 == 0x8000
	if !isResponse || resp.Flags&0x000f != RcodeNoError || len(resp.Answer) != 1 {
		return nil
	}

	answer := resp.Answer[0]
	if _, underDomain := answer.Name.TrimSuffix(domain); !underDomain || answer.Type != RRTypeTXT {
		return nil
	}

	payload, err := DecodeRDataTXT(answer.Data)
	if err != nil {
		return nil
	}
	return payload
}