// Mirror of https://github.com/XTLS/Xray-core.git
// Synced 2026-07-03 10:18:42 +00:00
// 363 lines, 6.4 KiB, Go

package xdns
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/rand"
|
|
"encoding/base32"
|
|
"encoding/binary"
|
|
go_errors "errors"
|
|
"io"
|
|
"net"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/xtls/xray-core/common"
|
|
"github.com/xtls/xray-core/common/errors"
|
|
"github.com/xtls/xray-core/transport/internet/finalmask"
|
|
)
|
|
|
|
const (
	// numPadding is the number of random padding bytes encode adds to a
	// query that carries payload data.
	numPadding = 3
	// numPaddingForPoll is the number of random padding bytes encode adds to
	// an empty (poll) query.
	numPaddingForPoll = 8
	// initPollDelay is the poll interval sendLoop starts with and falls back
	// to whenever it wakes up for any reason other than the poll timer.
	initPollDelay = 500 * time.Millisecond
	// maxPollDelay caps the poll interval while sendLoop backs off.
	maxPollDelay = 10 * time.Second
	// pollDelayMultiplier is the factor applied to the poll interval each
	// time the poll timer expires.
	pollDelayMultiplier = 2.0
	// pollLimit is the buffer capacity of pollChan, i.e. the maximum number
	// of pending "poll again" signals.
	pollLimit = 16
)
|
|
|
|
// base32Encoding is the standard base32 alphabet without '=' padding. encode
// uses it (lower-cased afterwards) to turn the raw query bytes into DNS labels.
var base32Encoding = base32.StdEncoding.WithPadding(base32.NoPadding)
|
|
|
|
// packet pairs a byte payload with the network address it was received from
// or should be sent to. It is the element type of the read and write queues.
type packet struct {
	p    []byte   // payload bytes
	addr net.Addr // peer address associated with the payload
}
|
|
|
|
// xdnsConnClient is the client side of the xdns mask. It wraps a raw
// net.PacketConn, sends outgoing payloads as DNS TXT queries under domain and
// extracts incoming payloads from the TXT answers it receives.
type xdnsConnClient struct {
	// PacketConn is the underlying connection that carries the DNS messages.
	net.PacketConn

	// clientID is a random 8-byte identifier written at the start of every
	// encoded query.
	clientID []byte
	// domain is the name suffix appended to every query and expected on
	// every answer.
	domain Name

	// pollChan is signalled by recvLoop when a response carried at least one
	// packet, which makes sendLoop issue another query right away.
	pollChan chan struct{}
	// readQueue holds decoded packets waiting to be returned by ReadFrom.
	readQueue chan *packet
	// writeQueue holds encoded queries waiting to be sent by sendLoop.
	writeQueue chan *packet

	// closed is set once the connection is shutting down.
	closed bool
	// mutex is held by WriteTo and by recvLoop's shutdown sequence so that a
	// send on writeQueue cannot race with the channel being closed.
	mutex sync.Mutex
}
|
|
|
|
func NewConnClient(c *Config, raw net.PacketConn) (net.PacketConn, error) {
|
|
domain, err := ParseName(c.Domain)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
conn := &xdnsConnClient{
|
|
PacketConn: raw,
|
|
|
|
clientID: make([]byte, 8),
|
|
domain: domain,
|
|
|
|
pollChan: make(chan struct{}, pollLimit),
|
|
readQueue: make(chan *packet, 256),
|
|
writeQueue: make(chan *packet, 256),
|
|
}
|
|
|
|
common.Must2(rand.Read(conn.clientID))
|
|
|
|
go conn.recvLoop()
|
|
go conn.sendLoop()
|
|
|
|
return conn, nil
|
|
}
|
|
|
|
func (c *xdnsConnClient) recvLoop() {
|
|
var buf [finalmask.UDPSize]byte
|
|
|
|
for {
|
|
if c.closed {
|
|
break
|
|
}
|
|
|
|
n, addr, err := c.PacketConn.ReadFrom(buf[:])
|
|
if err != nil || n == 0 {
|
|
if go_errors.Is(err, net.ErrClosed) || go_errors.Is(err, io.EOF) {
|
|
break
|
|
}
|
|
continue
|
|
}
|
|
|
|
resp, err := MessageFromWireFormat(buf[:n])
|
|
if err != nil {
|
|
errors.LogDebug(context.Background(), addr, " xdns from wireformat err ", err)
|
|
continue
|
|
}
|
|
|
|
payload := dnsResponsePayload(&resp, c.domain)
|
|
|
|
r := bytes.NewReader(payload)
|
|
anyPacket := false
|
|
for {
|
|
p, err := nextPacket(r)
|
|
if err != nil {
|
|
break
|
|
}
|
|
anyPacket = true
|
|
|
|
buf := make([]byte, len(p))
|
|
copy(buf, p)
|
|
select {
|
|
case c.readQueue <- &packet{
|
|
p: buf,
|
|
addr: addr,
|
|
}:
|
|
default:
|
|
errors.LogDebug(context.Background(), addr, " mask read err queue full")
|
|
}
|
|
}
|
|
|
|
if anyPacket {
|
|
select {
|
|
case c.pollChan <- struct{}{}:
|
|
default:
|
|
}
|
|
}
|
|
}
|
|
|
|
errors.LogDebug(context.Background(), "xdns closed")
|
|
|
|
close(c.pollChan)
|
|
close(c.readQueue)
|
|
|
|
c.mutex.Lock()
|
|
defer c.mutex.Unlock()
|
|
|
|
c.closed = true
|
|
close(c.writeQueue)
|
|
}
|
|
|
|
func (c *xdnsConnClient) sendLoop() {
|
|
var addr net.Addr
|
|
|
|
pollDelay := initPollDelay
|
|
pollTimer := time.NewTimer(pollDelay)
|
|
for {
|
|
var p *packet
|
|
pollTimerExpired := false
|
|
|
|
select {
|
|
case p = <-c.writeQueue:
|
|
default:
|
|
select {
|
|
case p = <-c.writeQueue:
|
|
case <-c.pollChan:
|
|
case <-pollTimer.C:
|
|
pollTimerExpired = true
|
|
}
|
|
}
|
|
|
|
if p != nil {
|
|
addr = p.addr
|
|
|
|
select {
|
|
case <-c.pollChan:
|
|
default:
|
|
}
|
|
} else if addr != nil {
|
|
encoded, _ := encode(nil, c.clientID, c.domain)
|
|
p = &packet{
|
|
p: encoded,
|
|
addr: addr,
|
|
}
|
|
}
|
|
|
|
if pollTimerExpired {
|
|
pollDelay = time.Duration(float64(pollDelay) * pollDelayMultiplier)
|
|
if pollDelay > maxPollDelay {
|
|
pollDelay = maxPollDelay
|
|
}
|
|
} else {
|
|
if !pollTimer.Stop() {
|
|
<-pollTimer.C
|
|
}
|
|
pollDelay = initPollDelay
|
|
}
|
|
pollTimer.Reset(pollDelay)
|
|
|
|
if c.closed {
|
|
return
|
|
}
|
|
|
|
if p != nil {
|
|
_, err := c.PacketConn.WriteTo(p.p, p.addr)
|
|
if go_errors.Is(err, net.ErrClosed) || go_errors.Is(err, io.ErrClosedPipe) {
|
|
c.closed = true
|
|
break
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *xdnsConnClient) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
|
packet, ok := <-c.readQueue
|
|
if !ok {
|
|
return 0, nil, net.ErrClosed
|
|
}
|
|
if len(p) < len(packet.p) {
|
|
errors.LogDebug(context.Background(), packet.addr, " mask read err short buffer ", len(p), " ", len(packet.p))
|
|
return 0, packet.addr, nil
|
|
}
|
|
copy(p, packet.p)
|
|
return len(packet.p), packet.addr, nil
|
|
}
|
|
|
|
func (c *xdnsConnClient) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
|
c.mutex.Lock()
|
|
defer c.mutex.Unlock()
|
|
|
|
if c.closed {
|
|
return 0, io.ErrClosedPipe
|
|
}
|
|
|
|
encoded, err := encode(p, c.clientID, c.domain)
|
|
if err != nil {
|
|
errors.LogDebug(context.Background(), addr, " xdns wireformat err ", err, " ", len(p))
|
|
return 0, nil
|
|
}
|
|
|
|
select {
|
|
case c.writeQueue <- &packet{
|
|
p: encoded,
|
|
addr: addr,
|
|
}:
|
|
return len(p), nil
|
|
default:
|
|
errors.LogDebug(context.Background(), addr, " mask write err queue full")
|
|
return 0, nil
|
|
}
|
|
}
|
|
|
|
func (c *xdnsConnClient) Close() error {
|
|
c.closed = true
|
|
return c.PacketConn.Close()
|
|
}
|
|
|
|
func encode(p []byte, clientID []byte, domain Name) ([]byte, error) {
|
|
var decoded []byte
|
|
{
|
|
if len(p) >= 224 {
|
|
return nil, errors.New("too long")
|
|
}
|
|
var buf bytes.Buffer
|
|
buf.Write(clientID[:])
|
|
n := numPadding
|
|
if len(p) == 0 {
|
|
n = numPaddingForPoll
|
|
}
|
|
buf.WriteByte(byte(224 + n))
|
|
_, _ = io.CopyN(&buf, rand.Reader, int64(n))
|
|
if len(p) > 0 {
|
|
buf.WriteByte(byte(len(p)))
|
|
buf.Write(p)
|
|
}
|
|
decoded = buf.Bytes()
|
|
}
|
|
|
|
encoded := make([]byte, base32Encoding.EncodedLen(len(decoded)))
|
|
base32Encoding.Encode(encoded, decoded)
|
|
encoded = bytes.ToLower(encoded)
|
|
labels := chunks(encoded, 63)
|
|
labels = append(labels, domain...)
|
|
name, err := NewName(labels)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var id uint16
|
|
_ = binary.Read(rand.Reader, binary.BigEndian, &id)
|
|
query := &Message{
|
|
ID: id,
|
|
Flags: 0x0100,
|
|
Question: []Question{
|
|
{
|
|
Name: name,
|
|
Type: RRTypeTXT,
|
|
Class: ClassIN,
|
|
},
|
|
},
|
|
Additional: []RR{
|
|
{
|
|
Name: Name{},
|
|
Type: RRTypeOPT,
|
|
Class: 4096,
|
|
TTL: 0,
|
|
Data: []byte{},
|
|
},
|
|
},
|
|
}
|
|
|
|
buf, err := query.WireFormat()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return buf, nil
|
|
}
|
|
|
|
// chunks splits p into consecutive sub-slices of at most n bytes each; only
// the last chunk may be shorter. The returned slices share p's backing array,
// no bytes are copied.
//
// An empty p yields nil. A non-positive n cannot make progress, so in that
// case p is returned as a single chunk instead of looping forever (n == 0)
// or panicking (n < 0).
func chunks(p []byte, n int) [][]byte {
	if len(p) == 0 {
		return nil
	}
	if n <= 0 {
		return [][]byte{p}
	}

	// len(p) >= 1 here, so this is ceil(len(p)/n) without overflow.
	result := make([][]byte, 0, (len(p)-1)/n+1)
	for len(p) > n {
		result = append(result, p[:n])
		p = p[n:]
	}
	return append(result, p)
}
|
|
|
|
// nextPacket reads one length-prefixed packet from r: a 2-byte big-endian
// length followed by that many payload bytes. The payload is returned in a
// newly allocated slice.
//
// It returns io.EOF when r is exhausted before the length prefix and
// io.ErrUnexpectedEOF when the prefix or the payload is truncated.
func nextPacket(r *bytes.Reader) ([]byte, error) {
	var hdr [2]byte
	if _, err := io.ReadFull(r, hdr[:]); err != nil {
		return nil, err
	}

	body := make([]byte, binary.BigEndian.Uint16(hdr[:]))
	if _, err := io.ReadFull(r, body); err != nil {
		// A length prefix with no payload at all is still a truncated packet.
		if err == io.EOF {
			err = io.ErrUnexpectedEOF
		}
		return body, err
	}
	return body, nil
}
|
|
|
|
func dnsResponsePayload(resp *Message, domain Name) []byte {
|
|
if resp.Flags&0x8000 != 0x8000 {
|
|
return nil
|
|
}
|
|
if resp.Flags&0x000f != RcodeNoError {
|
|
return nil
|
|
}
|
|
|
|
if len(resp.Answer) != 1 {
|
|
return nil
|
|
}
|
|
answer := resp.Answer[0]
|
|
|
|
_, ok := answer.Name.TrimSuffix(domain)
|
|
if !ok {
|
|
return nil
|
|
}
|
|
|
|
if answer.Type != RRTypeTXT {
|
|
return nil
|
|
}
|
|
payload, err := DecodeRDataTXT(answer.Data)
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
|
|
return payload
|
|
}
|