diff --git a/transport/internet/splithttp/client.go b/transport/internet/splithttp/client.go index a736d01d7..0e3a73f2e 100644 --- a/transport/internet/splithttp/client.go +++ b/transport/internet/splithttp/client.go @@ -9,6 +9,7 @@ import ( "net/http/httptrace" "sync" + "github.com/apernet/quic-go/http3" "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/errors" @@ -176,6 +177,15 @@ func (c *DefaultDialerClient) PostPacket(ctx context.Context, url string, sessio return nil } +// HTTP/1.1 and HTTP/2 will close itself, we only handle HTTP/3 here +func (c *DefaultDialerClient) Close() error { + transport := c.client.Transport + if h3Transport, ok := transport.(*http3.Transport); ok { + h3Transport.Close() + } + return nil +} + type WaitReadCloser struct { Wait chan struct{} io.ReadCloser diff --git a/transport/internet/splithttp/dialer.go b/transport/internet/splithttp/dialer.go index 0f1f8c59b..817d93552 100644 --- a/transport/internet/splithttp/dialer.go +++ b/transport/internet/splithttp/dialer.go @@ -259,6 +259,7 @@ func createHTTPClient(dest net.Destination, streamSettings *internet.MemoryStrea if err != nil { return nil, err } + context.AfterFunc(conn.Context(), func() { pktConn.Close() }) switch quicParams.Congestion { case "reno": @@ -425,10 +426,10 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me } if xmuxClient != nil { - xmuxClient.OpenUsage.Add(1) + xmuxClient.AddRunning() } if xmuxClient2 != nil && xmuxClient2 != xmuxClient { - xmuxClient2.OpenUsage.Add(1) + xmuxClient2.AddRunning() } var closed atomic.Int32 @@ -440,10 +441,10 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me return } if xmuxClient != nil { - xmuxClient.OpenUsage.Add(-1) + xmuxClient.DoneRunning() } if xmuxClient2 != nil && xmuxClient2 != xmuxClient { - xmuxClient2.OpenUsage.Add(-1) + xmuxClient2.DoneRunning() } }, } diff --git a/transport/internet/splithttp/mux.go b/transport/internet/splithttp/mux.go index 093ddd127..4897ace45 100644 --- a/transport/internet/splithttp/mux.go +++ b/transport/internet/splithttp/mux.go @@ -8,6 +8,7 @@ import ( "sync/atomic" "time" + "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/errors" ) @@ -17,10 +18,27 @@ type XmuxConn interface { type XmuxClient struct { XmuxConn XmuxConn - OpenUsage atomic.Int32 + Running atomic.Int32 leftUsage int32 LeftRequests atomic.Int32 UnreusableAt time.Time + NotUsed atomic.Bool +} + +func (c *XmuxClient) AddRunning() { + c.Running.Add(1) +} + +func (c *XmuxClient) DoneRunning() { + c.Running.Add(-1) + c.maybeClose() +} + +// close the XmuxConn if it is not used and has no running requests +func (c *XmuxClient) maybeClose() { + if c.NotUsed.Load() && c.Running.Load() <= 0 { + common.Close(c.XmuxConn) + } } type XmuxManager struct { @@ -68,10 +86,12 @@ func (m *XmuxManager) GetXmuxClient(ctx context.Context) *XmuxClient { // when l xmuxClient.LeftRequests.Load() <= 0 || (xmuxClient.UnreusableAt != time.Time{} && time.Now().After(xmuxClient.UnreusableAt)) { errors.LogDebug(ctx, "XMUX: removing xmuxClient, IsClosed() = ", xmuxClient.XmuxConn.IsClosed(), - ", OpenUsage = ", xmuxClient.OpenUsage.Load(), + ", Running = ", xmuxClient.Running.Load(), ", leftUsage = ", xmuxClient.leftUsage, ", LeftRequests = ", xmuxClient.LeftRequests.Load(), ", UnreusableAt = ", xmuxClient.UnreusableAt) + xmuxClient.NotUsed.Store(true) + xmuxClient.maybeClose() m.xmuxClients = append(m.xmuxClients[:i], m.xmuxClients[i+1:]...) } else { i++ @@ -91,7 +111,7 @@ func (m *XmuxManager) GetXmuxClient(ctx context.Context) *XmuxClient { // when l xmuxClients := make([]*XmuxClient, 0) if m.concurrency > 0 { for _, xmuxClient := range m.xmuxClients { - if xmuxClient.OpenUsage.Load() < m.concurrency { + if xmuxClient.Running.Load() < m.concurrency { xmuxClients = append(xmuxClients, xmuxClient) } } diff --git a/transport/internet/splithttp/mux_test.go b/transport/internet/splithttp/mux_test.go index 835d07f0c..2f5b3520e 100644 --- a/transport/internet/splithttp/mux_test.go +++ b/transport/internet/splithttp/mux_test.go @@ -63,7 +63,7 @@ func TestMaxConcurrency(t *testing.T) { xmuxClients := make(map[interface{}]struct{}) for i := 0; i < 64; i++ { xmuxClient := xmuxManager.GetXmuxClient(context.Background()) - xmuxClient.OpenUsage.Add(1) + xmuxClient.AddRunning() xmuxClients[xmuxClient] = struct{}{} } @@ -82,7 +82,7 @@ func TestDefault(t *testing.T) { xmuxClients := make(map[interface{}]struct{}) for i := 0; i < 64; i++ { xmuxClient := xmuxManager.GetXmuxClient(context.Background()) - xmuxClient.OpenUsage.Add(1) + xmuxClient.AddRunning() xmuxClients[xmuxClient] = struct{}{} }