diff --git a/common/buf/copy.go b/common/buf/copy.go index 4cc3be88..55881ebd 100644 --- a/common/buf/copy.go +++ b/common/buf/copy.go @@ -2,6 +2,7 @@ package buf import ( "io" + "sync" "time" "github.com/xtls/xray-core/common/errors" @@ -113,7 +114,12 @@ func Copy(reader Reader, writer Writer, options ...CopyOption) error { for _, option := range options { option(&handler) } - err := copyInternal(reader, writer, &handler) + var err error + if sReader, ok := reader.(*SingleReader); ok && false { + err = copyV(sReader, writer, &handler) + } else { + err = copyInternal(reader, writer, &handler) + } if err != nil && errors.Cause(err) != io.EOF { return err } @@ -133,3 +139,85 @@ func CopyOnceTimeout(reader Reader, writer Writer, timeout time.Duration) error } return writer.WriteMultiBuffer(mb) } + +func copyV(r *SingleReader, w Writer, handler *copyHandler) error { + // max packet len is 8192, so buffer channel size is 512, about 4MB memory usage + cache := make(chan *Buffer, 512) + stopRead := make(chan struct{}) + var rErr error + var wErr error + wg := sync.WaitGroup{} + wg.Add(2) + // downlink + go func() { + defer wg.Done() + defer close(cache) + for { + b, err := r.readBuffer() + if err == nil { + select { + case cache <- b: + // must be write error + case <-stopRead: + b.Release() + return + } + } else { + rErr = err + select { + case cache <- b: + case <-stopRead: + b.Release() + } + return + } + } + }() + // uplink + go func() { + defer wg.Done() + for { + b, ok := <-cache + if !ok { + return + } + var buffers = []*Buffer{b} + for stop := false; !stop; { + select { + case b, ok := <-cache: + if !ok { + stop = true + continue + } + buffers = append(buffers, b) + default: + stop = true + } + } + mb := MultiBuffer(buffers) + err := w.WriteMultiBuffer(mb) + for _, handler := range handler.onData { + handler(mb) + } + ReleaseMulti(mb) + if err != nil { + wErr = err + close(stopRead) + return + } + } + }() + wg.Wait() + for range cache { + // drain cache + b := <-cache + b.Release() + } + if wErr != nil { + return writeError{wErr} + } + if rErr != nil { + return readError{rErr} + } + return nil +} diff --git a/common/buf/reader.go b/common/buf/reader.go index 33d362d4..ca00043b 100644 --- a/common/buf/reader.go +++ b/common/buf/reader.go @@ -159,6 +159,11 @@ func (r *SingleReader) ReadMultiBuffer() (MultiBuffer, error) { return MultiBuffer{b}, err } +func (r *SingleReader) readBuffer() (*Buffer, error) { + b, err := ReadBuffer(r.Reader) + return b, err +} + // PacketReader is a Reader that read one Buffer every time. type PacketReader struct { io.Reader