From 2b4269962360bdb4abc9c1d0117d52a92162a4a9 Mon Sep 17 00:00:00 2001 From: LjhAUMEM Date: Fri, 29 May 2026 17:44:06 +0800 Subject: [PATCH] XICMP finalmask: Refactor & Speed up; Add multi `ips` and `dgram` mode (client) (#6168) https://github.com/XTLS/Xray-core/pull/5872#issuecomment-4198514120 https://github.com/XTLS/Xray-core/pull/6168#issuecomment-4571606666 https://github.com/XTLS/Xray-core/pull/6168#issuecomment-4573294637 Example: https://github.com/XTLS/Xray-core/pull/6168#issue-4484980927 Closes https://github.com/XTLS/Xray-core/discussions/5879 --- infra/conf/transport_internet.go | 17 +- transport/internet/finalmask/xicmp/client.go | 508 +++++++++--------- .../internet/finalmask/xicmp/config.pb.go | 24 +- .../internet/finalmask/xicmp/config.proto | 4 +- transport/internet/finalmask/xicmp/server.go | 492 ++++++++--------- 5 files changed, 520 insertions(+), 525 deletions(-) diff --git a/infra/conf/transport_internet.go b/infra/conf/transport_internet.go index 69a90922..7ef91e52 100644 --- a/infra/conf/transport_internet.go +++ b/infra/conf/transport_internet.go @@ -6,6 +6,7 @@ import ( "encoding/hex" "encoding/json" "math" + "net/netip" "net/url" "os" "regexp" @@ -1901,18 +1902,20 @@ func (c *Xdns) Build() (proto.Message, error) { } type Xicmp struct { - ListenIp string `json:"listenIp"` - Id uint16 `json:"id"` + DGRAM bool `json:"dgram"` + IPs []string `json:"ips"` } func (c *Xicmp) Build() (proto.Message, error) { - config := &xicmp.Config{ - Ip: c.ListenIp, - Id: int32(c.Id), + for _, ip := range c.IPs { + if _, err := netip.ParseAddr(ip); err != nil { + return nil, err + } } - if config.Ip == "" { - config.Ip = "0.0.0.0" + config := &xicmp.Config{ + DGRAM: c.DGRAM, + IPs: c.IPs, } return config, nil diff --git a/transport/internet/finalmask/xicmp/client.go b/transport/internet/finalmask/xicmp/client.go index d738b125..3777bd87 100644 --- a/transport/internet/finalmask/xicmp/client.go +++ b/transport/internet/finalmask/xicmp/client.go @@ -1,14 +1,21 @@ package xicmp import ( + "bytes" "context" + "crypto/rand" + "encoding/binary" + goerrors "errors" + "fmt" "io" + mathrand "math/rand" "net" - "strings" + "net/netip" "sync" "time" + _ "unsafe" - "github.com/xtls/xray-core/common/crypto" + "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/transport/internet/finalmask" "golang.org/x/net/icmp" @@ -16,149 +23,115 @@ import ( "golang.org/x/net/ipv6" ) -const ( - initPollDelay = 500 * time.Millisecond - maxPollDelay = 10 * time.Second - pollDelayMultiplier = 2.0 - pollLimit = 16 - windowSize = 1000 -) +var pool = sync.Pool{ + New: func() any { + return make([]byte, finalmask.UDPSize) + }, +} type packet struct { p []byte addr net.Addr -} - -type seqStatus struct { - needSeqByte bool - seqByte byte + err error } type xicmpConnClient struct { conn net.PacketConn - icmpConn *icmp.PacketConn - - typ icmp.Type - id int - seq int - proto int - seqStatus map[int]*seqStatus - - pollChan chan struct{} - readQueue chan *packet - writeQueue chan *packet - - closed bool - mutex sync.Mutex + icmp4 *icmp.PacketConn + icmp6 *icmp.PacketConn + udp bool + ips []netip.Addr + clientID [8]byte + id int + seq int + readCh chan packet + closedCh chan struct{} + mu sync.Mutex } func NewConnClient(c *Config, raw net.PacketConn) (net.PacketConn, error) { - network := "ip4:icmp" - typ := icmp.Type(ipv4.ICMPTypeEcho) - proto := 1 - if strings.Contains(c.Ip, ":") { - network = "ip6:ipv6-icmp" - typ = ipv6.ICMPTypeEchoRequest - proto = 58 + var icmp4, icmp6 *icmp.PacketConn + var err4, err6 error + if c.DGRAM { + icmp4, err4 = icmp.ListenPacket("udp4", "0.0.0.0") + icmp6, err6 = icmp.ListenPacket("udp6", "::") + } else { + icmp4, err4 = icmp.ListenPacket("ip4:icmp", "0.0.0.0") + icmp6, err6 = icmp.ListenPacket("ip6:ipv6-icmp", "::") + } + if err4 != nil || err6 != nil { + return nil, errors.Combine(err4, err6) } - icmpConn, err := icmp.ListenPacket(network, c.Ip) - if err != nil { - return nil, errors.New("xicmp listen err").Base(err) + ips := make([]netip.Addr, 0, len(c.IPs)) + for _, ip := range c.IPs { + ips = append(ips, netip.MustParseAddr(ip)) } - if c.Id == 0 { - c.Id = int32(crypto.RandBetween(0, 65535)) - } + var clientID [8]byte + common.Must2(rand.Read(clientID[:])) conn := &xicmpConnClient{ conn: raw, - icmpConn: icmpConn, - - typ: typ, - id: int(c.Id), - seq: 1, - proto: proto, - seqStatus: make(map[int]*seqStatus), - - pollChan: make(chan struct{}, pollLimit), - readQueue: make(chan *packet, 256), - writeQueue: make(chan *packet, 256), + icmp4: icmp4, + icmp6: icmp6, + udp: c.DGRAM, + ips: ips, + clientID: clientID, + id: mathrand.Intn(65536), + seq: 1, + readCh: make(chan packet), + closedCh: make(chan struct{}), } - go conn.recvLoop() - go conn.sendLoop() + go conn.recv4() + go conn.recv6() return conn, nil } -func (c *xicmpConnClient) encode(p []byte) ([]byte, error) { - c.mutex.Lock() - defer c.mutex.Unlock() - - needSeqByte := false - var seqByte byte - data := p - if len(p) > 0 { - needSeqByte = true - seqByte = p[0] - } - - msg := icmp.Message{ - Type: c.typ, - Code: 0, - Body: &icmp.Echo{ - ID: c.id, - Seq: c.seq, - Data: data, - }, - } - - buf, err := msg.Marshal(nil) - if err != nil { - return nil, err - } - - if len(buf) > finalmask.UDPSize { - return nil, errors.New("xicmp len(buf) > finalmask.UDPSize") - } - - c.seqStatus[c.seq] = &seqStatus{ - needSeqByte: needSeqByte, - seqByte: seqByte, - } - - delete(c.seqStatus, int(uint16(c.seq-windowSize))) - - c.seq++ - - if c.seq == 65536 { - delete(c.seqStatus, int(uint16(c.seq-windowSize))) - c.seq = 1 - } - - return buf, nil +func (c *xicmpConnClient) ring(a, b uint16) uint16 { + return min(a-b, b-a) } -func (c *xicmpConnClient) recvLoop() { - var buf [finalmask.UDPSize]byte +func (c *xicmpConnClient) closed() bool { + select { + case <-c.closedCh: + return true + default: + return false + } +} + +func (c *xicmpConnClient) recv4() { + var b [finalmask.UDPSize]byte for { - if c.closed { - break + if c.closed() { + return } - n, addr, err := c.icmpConn.ReadFrom(buf[:]) + n, addr, err := c.icmp4.ReadFrom(b[:]) + if err != nil { + var netErr net.Error + if goerrors.As(err, &netErr) && netErr.Timeout() { + select { + case c.readCh <- packet{ + err: err, + }: + case <-c.closedCh: + return + } + } + continue + } + + msg, err := icmp.ParseMessage(1, b[:n]) if err != nil { continue } - msg, err := icmp.ParseMessage(c.proto, buf[:n]) - if err != nil { - continue - } - - if msg.Type != ipv4.ICMPTypeEchoReply && msg.Type != ipv6.ICMPTypeEchoReply { + if msg.Type != ipv4.ICMPTypeEchoReply { continue } @@ -167,180 +140,223 @@ func (c *xicmpConnClient) recvLoop() { continue } - c.mutex.Lock() - seqStatus, ok := c.seqStatus[echo.Seq] - c.mutex.Unlock() + // errors.LogDebug(context.Background(), "id ", echo.ID, " seq ", echo.Seq, " addr ", addr) + if !c.udp && echo.ID != c.id { + continue + } + + if c.ring(uint16(echo.Seq), uint16(c.seq)) > 1000 { + continue + } + + if len(echo.Data) > 8 && bytes.Equal(echo.Data[:8], c.clientID[:]) { + continue + } + + p := pool.Get().([]byte)[:len(echo.Data)] + copy(p, echo.Data) + + if !c.udp { + addr = &net.UDPAddr{IP: addr.(*net.IPAddr).IP} + } + + select { + case c.readCh <- packet{ + p: p, + addr: addr, + }: + case <-c.closedCh: + pool.Put(p) + return + } + } +} + +func (c *xicmpConnClient) recv6() { + var b [finalmask.UDPSize]byte + + for { + if c.closed() { + break + } + + n, addr, err := c.icmp6.ReadFrom(b[:]) + if err != nil { + var netErr net.Error + if goerrors.As(err, &netErr) && netErr.Timeout() { + select { + case c.readCh <- packet{ + err: err, + }: + case <-c.closedCh: + return + } + } + continue + } + + msg, err := icmp.ParseMessage(58, b[:n]) + if err != nil { + continue + } + + if msg.Type != ipv6.ICMPTypeEchoReply { + continue + } + + echo, ok := msg.Body.(*icmp.Echo) if !ok { continue } - if seqStatus.needSeqByte { - if len(echo.Data) <= 1 { - continue - } - if echo.Data[0] == seqStatus.seqByte { - continue - } - echo.Data = echo.Data[1:] + // errors.LogDebug(context.Background(), "id ", echo.ID, " seq ", echo.Seq, " addr ", addr) + + if !c.udp && echo.ID != c.id { + continue } - if len(echo.Data) > 0 { - c.mutex.Lock() - delete(c.seqStatus, echo.Seq) - c.mutex.Unlock() - - buf := make([]byte, len(echo.Data)) - copy(buf, echo.Data) - select { - case c.readQueue <- &packet{ - p: buf, - addr: &net.UDPAddr{IP: addr.(*net.IPAddr).IP}, - }: - default: - errors.LogDebug(context.Background(), addr, " ", echo.Seq, " ", echo.ID, " mask read err queue full") - } - - select { - case c.pollChan <- struct{}{}: - default: - } + if c.ring(uint16(echo.Seq), uint16(c.seq)) > 1000 { + continue } - } - errors.LogDebug(context.Background(), "xicmp closed") + if len(echo.Data) > 8 && bytes.Equal(echo.Data[:8], c.clientID[:]) { + continue + } - close(c.pollChan) - close(c.readQueue) + p := pool.Get().([]byte)[:len(echo.Data)] + copy(p, echo.Data) - c.mutex.Lock() - defer c.mutex.Unlock() - - c.closed = true - close(c.writeQueue) -} - -func (c *xicmpConnClient) sendLoop() { - var addr net.Addr - - pollDelay := initPollDelay - pollTimer := time.NewTimer(pollDelay) - for { - var p *packet - pollTimerExpired := false + if !c.udp { + addr = &net.UDPAddr{IP: addr.(*net.IPAddr).IP} + } select { - case p = <-c.writeQueue: - default: - select { - case p = <-c.writeQueue: - case <-c.pollChan: - case <-pollTimer.C: - pollTimerExpired = true - } - } - - if p != nil { - addr = p.addr - - select { - case <-c.pollChan: - default: - } - } else if addr != nil { - encoded, _ := c.encode(nil) - p = &packet{ - p: encoded, - addr: addr, - } - } - - if pollTimerExpired { - pollDelay = time.Duration(float64(pollDelay) * pollDelayMultiplier) - if pollDelay > maxPollDelay { - pollDelay = maxPollDelay - } - } else { - if !pollTimer.Stop() { - <-pollTimer.C - } - pollDelay = initPollDelay - } - pollTimer.Reset(pollDelay) - - if c.closed { + case c.readCh <- packet{ + p: p, + addr: addr, + }: + case <-c.closedCh: + pool.Put(p) return } - - if p != nil { - _, err := c.icmpConn.WriteTo(p.p, p.addr) - if err != nil { - errors.LogDebug(context.Background(), p.addr, " xicmp writeto err ", err) - } - } } } func (c *xicmpConnClient) ReadFrom(p []byte) (n int, addr net.Addr, err error) { - packet, ok := <-c.readQueue - if !ok { - return 0, nil, net.ErrClosed + select { + case packet := <-c.readCh: + if packet.p != nil { + n = copy(p, packet.p) + pool.Put(packet.p) + } + return n, packet.addr, packet.err + case <-c.closedCh: + return 0, nil, io.EOF } - if len(p) < len(packet.p) { - errors.LogDebug(context.Background(), packet.addr, " mask read err short buffer ", len(p), " ", len(packet.p)) - return 0, packet.addr, nil - } - copy(p, packet.p) - return len(packet.p), packet.addr, nil } func (c *xicmpConnClient) WriteTo(p []byte, addr net.Addr) (n int, err error) { - encoded, err := c.encode(p) + if len(p)+16 > finalmask.UDPSize { + errors.LogError(context.Background(), "drop packet to ", addr, " with size ", len(p)) + return 0, nil + } + + c.mu.Lock() + seq := c.seq + c.seq += 1 + c.seq %= 65536 + c.mu.Unlock() + + ip := addr.(*net.UDPAddr).IP + if len(c.ips) > 0 { + ip = c.ips[mathrand.Intn(len(c.ips))].AsSlice() + } + + if c.udp { + addr = &net.UDPAddr{IP: ip} + } else { + addr = &net.IPAddr{IP: ip} + } + + b := pool.Get().([]byte)[:finalmask.UDPSize] + defer pool.Put(b) + + copy(b[8:], c.clientID[:]) + copy(b[16:], p) + + if ip.To4() != nil { + b = marshal(b, ipv4.ICMPTypeEcho, c.id, seq, 8+len(p)) + _, err = c.icmp4.WriteTo(b, addr) + } else { + b = marshal(b, ipv6.ICMPTypeEchoRequest, c.id, seq, 8+len(p)) + _, err = c.icmp6.WriteTo(b, addr) + } + if err != nil { - errors.LogDebug(context.Background(), addr, " xicmp wireformat err ", err) - return 0, nil + errors.LogErrorInner(context.Background(), err, "xicmp write") + return 0, err } - c.mutex.Lock() - defer c.mutex.Unlock() - - if c.closed { - return 0, io.ErrClosedPipe - } - - select { - case c.writeQueue <- &packet{ - p: encoded, - addr: &net.IPAddr{IP: addr.(*net.UDPAddr).IP}, - }: - return len(p), nil - default: - errors.LogDebug(context.Background(), addr, " mask write err queue full") - return 0, nil - } + return len(p), nil } func (c *xicmpConnClient) Close() error { - c.closed = true - _ = c.icmpConn.Close() - return c.conn.Close() + c.mu.Lock() + defer c.mu.Unlock() + if c.closed() { + return nil + } + close(c.closedCh) + _ = c.icmp4.Close() + _ = c.icmp6.Close() + _ = c.conn.Close() + return nil } func (c *xicmpConnClient) LocalAddr() net.Addr { - return &net.UDPAddr{ - IP: c.icmpConn.LocalAddr().(*net.IPAddr).IP, - Port: c.id, - } + return c.conn.LocalAddr() } func (c *xicmpConnClient) SetDeadline(t time.Time) error { - return c.icmpConn.SetDeadline(t) + _ = c.icmp4.SetDeadline(t) + _ = c.icmp6.SetDeadline(t) + return nil } func (c *xicmpConnClient) SetReadDeadline(t time.Time) error { - return c.icmpConn.SetReadDeadline(t) + _ = c.icmp4.SetReadDeadline(t) + _ = c.icmp6.SetReadDeadline(t) + return nil } func (c *xicmpConnClient) SetWriteDeadline(t time.Time) error { - return c.icmpConn.SetWriteDeadline(t) + _ = c.icmp4.SetWriteDeadline(t) + _ = c.icmp6.SetWriteDeadline(t) + return nil +} + +//go:linkname checksum golang.org/x/net/icmp.checksum +func checksum(b []byte) uint16 + +func marshal(b []byte, typ icmp.Type, id, seq int, dataLen int) []byte { + is4 := false + switch typ := typ.(type) { + case ipv4.ICMPType: + is4 = true + b[0] = byte(typ) + case ipv6.ICMPType: + b[0] = byte(typ) + default: + panic(fmt.Sprintf("%T %v", typ, typ)) + } + clear(b[1:4]) + binary.BigEndian.PutUint16(b[4:], uint16(id)) + binary.BigEndian.PutUint16(b[6:], uint16(seq)) + if is4 { + s := checksum(b[:8+dataLen]) + b[2] ^= byte(s) + b[3] ^= byte(s >> 8) + } + return b[:8+dataLen] } diff --git a/transport/internet/finalmask/xicmp/config.pb.go b/transport/internet/finalmask/xicmp/config.pb.go index be75dfe5..0f88ef1e 100644 --- a/transport/internet/finalmask/xicmp/config.pb.go +++ b/transport/internet/finalmask/xicmp/config.pb.go @@ -23,8 +23,8 @@ const ( type Config struct { state protoimpl.MessageState `protogen:"open.v1"` - Ip string `protobuf:"bytes,1,opt,name=ip,proto3" json:"ip,omitempty"` - Id int32 `protobuf:"varint,2,opt,name=id,proto3" json:"id,omitempty"` + DGRAM bool `protobuf:"varint,1,opt,name=DGRAM,proto3" json:"DGRAM,omitempty"` + IPs []string `protobuf:"bytes,2,rep,name=IPs,proto3" json:"IPs,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -59,28 +59,28 @@ func (*Config) Descriptor() ([]byte, []int) { return file_transport_internet_finalmask_xicmp_config_proto_rawDescGZIP(), []int{0} } -func (x *Config) GetIp() string { +func (x *Config) GetDGRAM() bool { if x != nil { - return x.Ip + return x.DGRAM } - return "" + return false } -func (x *Config) GetId() int32 { +func (x *Config) GetIPs() []string { if x != nil { - return x.Id + return x.IPs } - return 0 + return nil } var File_transport_internet_finalmask_xicmp_config_proto protoreflect.FileDescriptor const file_transport_internet_finalmask_xicmp_config_proto_rawDesc = "" + "\n" + - "/transport/internet/finalmask/xicmp/config.proto\x12'xray.transport.internet.finalmask.xicmp\"(\n" + - "\x06Config\x12\x0e\n" + - "\x02ip\x18\x01 \x01(\tR\x02ip\x12\x0e\n" + - "\x02id\x18\x02 \x01(\x05R\x02idB\x97\x01\n" + + "/transport/internet/finalmask/xicmp/config.proto\x12'xray.transport.internet.finalmask.xicmp\"0\n" + + "\x06Config\x12\x14\n" + + "\x05DGRAM\x18\x01 \x01(\bR\x05DGRAM\x12\x10\n" + + "\x03IPs\x18\x02 \x03(\tR\x03IPsB\x97\x01\n" + "+com.xray.transport.internet.finalmask.xicmpP\x01Z= idleTimeout { - close(q.queue) - delete(c.writeQueueMap, key) - } - } - +func (c *xicmpConnServer) closed() bool { + select { + case <-c.closedCh: + return true + default: return false } +} +func (c *xicmpConnServer) clean() { + ticker := time.NewTicker(time.Minute / 2) + defer ticker.Stop() for { - time.Sleep(idleTimeout / 2) - if f() { + select { + case <-ticker.C: + now := time.Now() + c.mu.Lock() + for key, r := range c.rec { + if now.Sub(r.last) > time.Minute { + delete(c.rec, key) + } + } + c.mu.Unlock() + case <-c.closedCh: return } } } -func (c *xicmpConnServer) ensureQueue(addr net.Addr) *queue { - if c.closed { - return nil - } - - q, ok := c.writeQueueMap[addr.String()] - if !ok { - q = &queue{ - queue: make(chan []byte, 512), - } - c.writeQueueMap[addr.String()] = q - } - q.last = time.Now() - - return q -} - -func (c *xicmpConnServer) encode(p []byte, id int, seq int, needSeqByte bool, seqByte byte) ([]byte, error) { - data := p - if needSeqByte { - b2 := c.randUntil(seqByte) - data = append([]byte{b2}, p...) - } - - msg := icmp.Message{ - Type: c.typ, - Code: 0, - Body: &icmp.Echo{ - ID: id, - Seq: seq, - Data: data, - }, - } - - buf, err := msg.Marshal(nil) - if err != nil { - return nil, err - } - - if len(buf) > finalmask.UDPSize { - return nil, errors.New("xicmp len(buf) > finalmask.UDPSize") - } - - return buf, nil -} - -func (c *xicmpConnServer) randUntil(b1 byte) byte { - b2 := byte(crypto.RandBetween(0, 255)) - for { - if b2 != b1 { - return b2 - } - b2 = byte(crypto.RandBetween(0, 255)) - } -} - -func (c *xicmpConnServer) recvLoop() { - var buf [finalmask.UDPSize]byte +func (c *xicmpConnServer) recv4() { + var b [finalmask.UDPSize]byte for { - if c.closed { - break + if c.closed() { + return } - n, addr, err := c.icmpConn.ReadFrom(buf[:]) + n, addr, err := c.icmp4.ReadFrom(b[:]) + if err != nil { + var netErr net.Error + if goerrors.As(err, &netErr) && netErr.Timeout() { + select { + case c.readCh <- packet{ + err: err, + }: + case <-c.closedCh: + return + } + } + continue + } + + msg, err := icmp.ParseMessage(1, b[:n]) if err != nil { continue } - msg, err := icmp.ParseMessage(c.proto, buf[:n]) - if err != nil { - continue - } - - if msg.Type != ipv4.ICMPTypeEcho && msg.Type != ipv6.ICMPTypeEchoRequest { + if msg.Type != ipv4.ICMPTypeEcho { continue } @@ -197,179 +143,209 @@ func (c *xicmpConnServer) recvLoop() { continue } - if c.config.Id != 0 && echo.ID != int(c.config.Id) { + if len(echo.Data) <= 8 { continue } - needSeqByte := false - var seqByte byte + if len(c.ips) > 0 { + netipAddr, ok := netip.AddrFromSlice(addr.(*net.IPAddr).IP) + if !ok { + continue + } - if len(echo.Data) > 0 { - needSeqByte = true - seqByte = echo.Data[0] - - buf := make([]byte, len(echo.Data)) - copy(buf, echo.Data) - select { - case c.readQueue <- &packet{ - p: buf, - addr: &net.UDPAddr{ - IP: addr.(*net.IPAddr).IP, - Port: echo.ID, - }, - }: - default: - errors.LogDebug(context.Background(), addr, " ", echo.ID, " ", echo.Seq, " mask read err queue full") + if _, ok := c.ips[netipAddr]; !ok { + continue } } - select { - case c.ch <- &record{ - id: echo.ID, - seq: echo.Seq, - needSeqByte: needSeqByte, - seqByte: seqByte, - addr: &net.UDPAddr{ - IP: addr.(*net.IPAddr).IP, - Port: echo.ID, - }, - }: - default: - errors.LogDebug(context.Background(), addr, " ", echo.ID, " ", echo.Seq, " mask read err record queue full") + cAddr := clientIDToAddr([8]byte(echo.Data[:8])) + + c.mu.Lock() + c.rec[cAddr.String()] = record{ + id: echo.ID, + seq: echo.Seq, + addr: addr, + last: time.Now(), } - } + c.mu.Unlock() - errors.LogDebug(context.Background(), "xicmp closed") + p := pool.Get().([]byte)[:len(echo.Data[8:])] + copy(p, echo.Data[8:]) - close(c.ch) - close(c.readQueue) - - c.mutex.Lock() - defer c.mutex.Unlock() - - c.closed = true - for key, q := range c.writeQueueMap { - close(q.queue) - delete(c.writeQueueMap, key) + select { + case c.readCh <- packet{ + p: p, + addr: cAddr, + }: + case <-c.closedCh: + pool.Put(p) + return + } } } -func (c *xicmpConnServer) sendLoop() { - var nextRec *record - for { - rec := nextRec - nextRec = nil +func (c *xicmpConnServer) recv6() { + var b [finalmask.UDPSize]byte - if rec == nil { - var ok bool - rec, ok = <-c.ch + for { + if c.closed() { + return + } + + n, addr, err := c.icmp6.ReadFrom(b[:]) + if err != nil { + var netErr net.Error + if goerrors.As(err, &netErr) && netErr.Timeout() { + select { + case c.readCh <- packet{ + err: err, + }: + case <-c.closedCh: + return + } + } + continue + } + + msg, err := icmp.ParseMessage(58, b[:n]) + if err != nil { + continue + } + + if msg.Type != ipv6.ICMPTypeEchoRequest { + continue + } + + echo, ok := msg.Body.(*icmp.Echo) + if !ok { + continue + } + + if len(echo.Data) <= 8 { + continue + } + + if len(c.ips) > 0 { + netipAddr, ok := netip.AddrFromSlice(addr.(*net.IPAddr).IP) if !ok { - break + continue + } + + if _, ok := c.ips[netipAddr]; !ok { + continue } } - c.mutex.Lock() - q := c.ensureQueue(rec.addr) - if q == nil { - c.mutex.Unlock() - return + cAddr := clientIDToAddr([8]byte(echo.Data[:8])) + + c.mu.Lock() + c.rec[cAddr.String()] = record{ + id: echo.ID, + seq: echo.Seq, + addr: addr, + last: time.Now(), } - c.mutex.Unlock() + c.mu.Unlock() - var p []byte - - timer := time.NewTimer(maxResponseDelay) + p := pool.Get().([]byte)[:len(echo.Data[8:])] + copy(p, echo.Data[8:]) select { - case p = <-q.queue: - default: - select { - case p = <-q.queue: - case <-timer.C: - case nextRec = <-c.ch: - } - } - - timer.Stop() - - if len(p) == 0 { - continue - } - - buf, err := c.encode(p, rec.id, rec.seq, rec.needSeqByte, rec.seqByte) - if err != nil { - errors.LogDebug(context.Background(), rec.addr, " ", rec.id, " ", rec.seq, " xicmp wireformat err ", err) - continue - } - - if c.closed { + case c.readCh <- packet{ + p: p, + addr: cAddr, + }: + case <-c.closedCh: + pool.Put(p) return } - - _, err = c.icmpConn.WriteTo(buf, &net.IPAddr{IP: rec.addr.(*net.UDPAddr).IP}) - if err != nil { - errors.LogDebug(context.Background(), rec.addr, " ", rec.id, " ", rec.seq, " xicmp writeto err ", err) - } } } func (c *xicmpConnServer) ReadFrom(p []byte) (n int, addr net.Addr, err error) { - packet, ok := <-c.readQueue - if !ok { - return 0, nil, net.ErrClosed + select { + case packet := <-c.readCh: + if packet.p != nil { + n = copy(p, packet.p) + pool.Put(packet.p) + } + return n, packet.addr, packet.err + case <-c.closedCh: + return 0, nil, io.EOF } - if len(p) < len(packet.p) { - errors.LogDebug(context.Background(), packet.addr, " mask read err short buffer ", len(p), " ", len(packet.p)) - return 0, packet.addr, nil - } - copy(p, packet.p) - return len(packet.p), packet.addr, nil } func (c *xicmpConnServer) WriteTo(p []byte, addr net.Addr) (n int, err error) { - if len(p)+8+1 > finalmask.UDPSize { - errors.LogDebug(context.Background(), addr, " mask write err short write ", len(p), "+8+1 > ", finalmask.UDPSize) + if len(p)+8 > finalmask.UDPSize { + errors.LogError(context.Background(), "drop packet to ", addr, " with size ", len(p)) return 0, nil } - c.mutex.Lock() - defer c.mutex.Unlock() - - q := c.ensureQueue(addr) - if q == nil { - return 0, io.ErrClosedPipe - } - - buf := make([]byte, len(p)) - copy(buf, p) - - select { - case q.queue <- buf: - return len(p), nil - default: - // errors.LogDebug(context.Background(), addr, " mask write err queue full") + c.mu.Lock() + r, ok := c.rec[addr.String()] + if !ok { + errors.LogError(context.Background(), "drop packet to ", addr, " with size ", len(p)) + c.mu.Unlock() return 0, nil } + r.last = time.Now() + c.rec[addr.String()] = r + c.mu.Unlock() + + // errors.LogDebug(context.Background(), "id ", r.id, " seq ", r.seq, " addr ", r.addr) + + b := pool.Get().([]byte)[:finalmask.UDPSize] + defer pool.Put(b) + + copy(b[8:], p) + + if r.addr.(*net.IPAddr).IP.To4() != nil { + b = marshal(b, ipv4.ICMPTypeEchoReply, r.id, r.seq, len(p)) + _, err = c.icmp4.WriteTo(b, r.addr) + } else { + b = marshal(b, ipv6.ICMPTypeEchoReply, r.id, r.seq, len(p)) + _, err = c.icmp6.WriteTo(b, r.addr) + } + + if err != nil { + errors.LogErrorInner(context.Background(), err, "xicmp write") + return 0, err + } + + return len(p), nil } func (c *xicmpConnServer) Close() error { - c.closed = true - _ = c.icmpConn.Close() - return c.conn.Close() + c.mu.Lock() + defer c.mu.Unlock() + if c.closed() { + return nil + } + close(c.closedCh) + _ = c.icmp4.Close() + _ = c.icmp6.Close() + _ = c.conn.Close() + return nil } func (c *xicmpConnServer) LocalAddr() net.Addr { - return &net.UDPAddr{IP: c.icmpConn.LocalAddr().(*net.IPAddr).IP} + return c.conn.LocalAddr() } func (c *xicmpConnServer) SetDeadline(t time.Time) error { - return c.icmpConn.SetDeadline(t) + _ = c.icmp4.SetDeadline(t) + _ = c.icmp6.SetDeadline(t) + return nil } func (c *xicmpConnServer) SetReadDeadline(t time.Time) error { - return c.icmpConn.SetReadDeadline(t) + _ = c.icmp4.SetReadDeadline(t) + _ = c.icmp6.SetReadDeadline(t) + return nil } func (c *xicmpConnServer) SetWriteDeadline(t time.Time) error { - return c.icmpConn.SetWriteDeadline(t) + _ = c.icmp4.SetWriteDeadline(t) + _ = c.icmp6.SetWriteDeadline(t) + return nil }