diff --git a/proxy/wireguard/bind.go b/proxy/wireguard/bind.go index 515afaa5..53f028c7 100644 --- a/proxy/wireguard/bind.go +++ b/proxy/wireguard/bind.go @@ -25,13 +25,23 @@ type netReadInfo struct { err error } +// receivedPacket represents a packet received from a peer connection +type receivedPacket struct { + data []byte + endpoint conn.Endpoint + err error +} + // reduce duplicated code type netBind struct { dns dns.Client dnsOption dns.IPOption - workers int - readQueue chan *netReadInfo + workers int + readQueue chan *netReadInfo + packetQueue chan *receivedPacket + startedMutex sync.Mutex + started bool } // SetMark implements conn.Bind @@ -80,6 +90,35 @@ func (bind *netBind) BatchSize() int { // Open implements conn.Bind func (bind *netBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) { bind.readQueue = make(chan *netReadInfo) + bind.packetQueue = make(chan *receivedPacket, 100) + + // Start a dispatcher goroutine that matches readQueue requests with received packets + bind.startedMutex.Lock() + if !bind.started { + bind.started = true + go func() { + for { + packet, ok := <-bind.packetQueue + if !ok { + return + } + + // Wait for a read request from WireGuard + request, ok := <-bind.readQueue + if !ok { + return + } + + // Copy packet data to the request buffer + n := copy(request.buff, packet.data) + request.bytes = n + request.endpoint = packet.endpoint + request.err = packet.err + request.waiter.Done() + } + }() + } + bind.startedMutex.Unlock() fun := func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) { defer func() { @@ -115,6 +154,9 @@ func (bind *netBind) Close() error { if bind.readQueue != nil { close(bind.readQueue) } + if bind.packetQueue != nil { + close(bind.packetQueue) + } return nil } @@ -133,30 +175,43 @@ func (bind *netBindClient) connectTo(endpoint *netEndpoint) error { } endpoint.conn = c - go func(readQueue <-chan *netReadInfo, endpoint *netEndpoint) { + // Start a goroutine that continuously reads from this connection + // and sends received packets to the packet queue + go func(conn net.Conn, endpoint *netEndpoint) { + const maxPacketSize = 1500 for { - v, ok := <-readQueue - if !ok { + buf := make([]byte, maxPacketSize) + n, err := conn.Read(buf) + + if n > 3 { + // Clear reserved bytes + buf[1] = 0 + buf[2] = 0 + buf[3] = 0 + } + + packet := &receivedPacket{ + data: buf[:n], + endpoint: endpoint, + err: err, + } + + // Try to send packet to queue; if queue is full or closed, exit + select { + case bind.packetQueue <- packet: + // Packet sent successfully + default: + // Queue is full or closed, exit goroutine + endpoint.conn = nil return } - i, err := c.Read(v.buff) - - if i > 3 { - v.buff[1] = 0 - v.buff[2] = 0 - v.buff[3] = 0 - } - - v.bytes = i - v.endpoint = endpoint - v.err = err - v.waiter.Done() + if err != nil { endpoint.conn = nil return } } - }(bind.readQueue, endpoint) + }(c, endpoint) return nil }