From dda2b10c9d18f2329d65f4a2effc09fed6ac9cd1 Mon Sep 17 00:00:00 2001 From: yiguodev <147401898+yiguodev@users.noreply.github.com> Date: Wed, 24 Jun 2026 20:00:22 +0800 Subject: [PATCH] TUN inbound: Add traffic counters; Metrics: Rely on instance (#6349) https://github.com/XTLS/Xray-core/pull/6349#issuecomment-4775121300 --- app/metrics/metrics.go | 210 ++++++++++++++++++++++++++---------- app/metrics/metrics_test.go | 161 +++++++++++++++++++++++++++ proxy/tun/handler.go | 28 +++++ proxy/tun/handler_test.go | 133 +++++++++++++++++++++++ 4 files changed, 473 insertions(+), 59 deletions(-) create mode 100644 app/metrics/metrics_test.go create mode 100644 proxy/tun/handler_test.go diff --git a/app/metrics/metrics.go b/app/metrics/metrics.go index a2ae2f30c..6d1cff721 100644 --- a/app/metrics/metrics.go +++ b/app/metrics/metrics.go @@ -2,15 +2,18 @@ package metrics import ( "context" + "encoding/json" + stderrors "errors" "expvar" + stdnet "net" "net/http" - _ "net/http/pprof" + "net/http/pprof" "strings" "github.com/xtls/xray-core/app/observatory" "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/errors" - "github.com/xtls/xray-core/common/net" + xnet "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/signal/done" "github.com/xtls/xray-core/core" "github.com/xtls/xray-core/features/extension" @@ -21,15 +24,17 @@ import ( type MetricsHandler struct { ohm outbound.Manager statsManager feature_stats.Manager - observatory extension.Observatory + ctx context.Context tag string listen string - tcpListener net.Listener + tcpListener xnet.Listener + listener *OutboundListener } // NewMetricsHandler creates a new MetricsHandler based on the given config. func NewMetricsHandler(ctx context.Context, config *Config) (*MetricsHandler, error) { c := &MetricsHandler{ + ctx: ctx, tag: config.Tag, listen: config.Listen, } @@ -37,46 +42,6 @@ func NewMetricsHandler(ctx context.Context, config *Config) (*MetricsHandler, er c.statsManager = sm c.ohm = om })) - expvar.Publish("stats", expvar.Func(func() interface{} { - resp := map[string]map[string]map[string]int64{ - "inbound": {}, - "outbound": {}, - "user": {}, - } - c.statsManager.VisitCounters(func(name string, counter feature_stats.Counter) bool { - nameSplit := strings.Split(name, ">>>") - typeName, tagOrUser, direction := nameSplit[0], nameSplit[1], nameSplit[3] - if item, found := resp[typeName][tagOrUser]; found { - item[direction] = counter.Value() - } else { - resp[typeName][tagOrUser] = map[string]int64{ - direction: counter.Value(), - } - } - return true - }) - return resp - })) - expvar.Publish("observatory", expvar.Func(func() interface{} { - if c.observatory == nil { - common.Must(core.RequireFeatures(ctx, func(observatory extension.Observatory) error { - c.observatory = observatory - return nil - })) - if c.observatory == nil { - return nil - } - } - resp := map[string]*observatory.OutboundStatus{} - if o, err := c.observatory.GetObservation(context.Background()); err != nil { - return err - } else { - for _, x := range o.(*observatory.ObservationResult).GetStatus() { - resp[x.OutboundTag] = x - } - } - return resp - })) return c, nil } @@ -85,45 +50,172 @@ func (p *MetricsHandler) Type() interface{} { } func (p *MetricsHandler) Start() error { + handler := p.httpHandler() + // direct listen a port if listen is set if p.listen != "" { - TCPlistener, err := net.Listen("tcp", p.listen) + TCPlistener, err := xnet.Listen("tcp", p.listen) if err != nil { return err } p.tcpListener = TCPlistener errors.LogInfo(context.Background(), "Metrics server listening on ", p.listen) - go func() { - if err := http.Serve(TCPlistener, http.DefaultServeMux); err != nil { - errors.LogErrorInner(context.Background(), err, "failed to start metrics server") - } - }() + go p.serve(TCPlistener, handler) + } + + if p.tag == "" { + if p.tcpListener == nil { + return errors.New("metrics must have a tag or listen address") + } + return nil } listener := &OutboundListener{ - buffer: make(chan net.Conn, 4), + buffer: make(chan xnet.Conn, 4), done: done.New(), } + p.listener = listener - go func() { - if err := http.Serve(listener, http.DefaultServeMux); err != nil { - errors.LogErrorInner(context.Background(), err, "failed to start metrics server") - } - }() + go p.serve(listener, handler) if err := p.ohm.RemoveHandler(context.Background(), p.tag); err != nil { errors.LogInfo(context.Background(), "failed to remove existing handler") } - return p.ohm.AddHandler(context.Background(), &Outbound{ + if err := p.ohm.AddHandler(context.Background(), &Outbound{ tag: p.tag, listener: listener, - }) + }); err != nil { + if closeErr := p.Close(); closeErr != nil { + errors.LogErrorInner(context.Background(), closeErr, "failed to close metrics server after start failure") + } + return err + } + + return nil } func (p *MetricsHandler) Close() error { - return nil + var errs []error + if p.tcpListener != nil { + errs = append(errs, p.tcpListener.Close()) + p.tcpListener = nil + } + if p.listener != nil { + errs = append(errs, p.listener.Close()) + p.listener = nil + } + if p.ohm != nil && p.tag != "" { + if err := p.ohm.RemoveHandler(context.Background(), p.tag); err != nil { + errors.LogInfo(context.Background(), "failed to remove metrics handler") + } + } + return errors.Combine(errs...) +} + +func (p *MetricsHandler) serve(listener xnet.Listener, handler http.Handler) { + if err := http.Serve(listener, handler); err != nil && !isClosedListenerError(err) { + errors.LogErrorInner(context.Background(), err, "failed to start metrics server") + } +} + +func isClosedListenerError(err error) bool { + if err == nil { + return true + } + if stderrors.Is(err, stdnet.ErrClosed) || stderrors.Is(err, http.ErrServerClosed) { + return true + } + errText := err.Error() + return strings.Contains(errText, "listen closed") || + strings.Contains(errText, "use of closed network connection") +} + +func (p *MetricsHandler) httpHandler() http.Handler { + mux := http.NewServeMux() + mux.HandleFunc("/debug/vars", p.handleDebugVars) + mux.HandleFunc("/debug/pprof/", pprof.Index) + mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline) + mux.HandleFunc("/debug/pprof/profile", pprof.Profile) + mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol) + mux.HandleFunc("/debug/pprof/trace", pprof.Trace) + return mux +} + +func (p *MetricsHandler) handleDebugVars(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + vars := map[string]json.RawMessage{} + expvar.Do(func(kv expvar.KeyValue) { + value := json.RawMessage(kv.Value.String()) + if !json.Valid(value) { + value = json.RawMessage("null") + } + vars[kv.Key] = value + }) + vars["stats"] = marshalJSON(p.stats()) + vars["observatory"] = marshalJSON(p.observatoryStatus()) + + payload, err := json.Marshal(vars) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.Write(payload) +} + +func marshalJSON(value interface{}) json.RawMessage { + data, err := json.Marshal(value) + if err != nil { + return json.RawMessage("null") + } + return data +} + +func (p *MetricsHandler) stats() map[string]map[string]map[string]int64 { + resp := map[string]map[string]map[string]int64{ + "inbound": {}, + "outbound": {}, + "user": {}, + } + p.statsManager.VisitCounters(func(name string, counter feature_stats.Counter) bool { + nameSplit := strings.Split(name, ">>>") + if len(nameSplit) < 4 { + return true + } + typeName, tagOrUser, direction := nameSplit[0], nameSplit[1], nameSplit[3] + items, found := resp[typeName] + if !found { + items = map[string]map[string]int64{} + resp[typeName] = items + } + if item, found := items[tagOrUser]; found { + item[direction] = counter.Value() + } else { + items[tagOrUser] = map[string]int64{ + direction: counter.Value(), + } + } + return true + }) + return resp +} + +func (p *MetricsHandler) observatoryStatus() interface{} { + feature := core.MustFromContext(p.ctx).GetFeature(extension.ObservatoryType()) + if feature == nil { + return nil + } + observatoryFeature := feature.(extension.Observatory) + resp := map[string]*observatory.OutboundStatus{} + if o, err := observatoryFeature.GetObservation(context.Background()); err != nil { + return err + } else { + for _, x := range o.(*observatory.ObservationResult).GetStatus() { + resp[x.OutboundTag] = x + } + } + return resp } func init() { diff --git a/app/metrics/metrics_test.go b/app/metrics/metrics_test.go new file mode 100644 index 000000000..a493887f8 --- /dev/null +++ b/app/metrics/metrics_test.go @@ -0,0 +1,161 @@ +package metrics + +import ( + "context" + "encoding/json" + stdnet "net" + "net/http" + "net/http/httptest" + "testing" + + "github.com/xtls/xray-core/app/dispatcher" + "github.com/xtls/xray-core/app/proxyman" + _ "github.com/xtls/xray-core/app/proxyman/inbound" + _ "github.com/xtls/xray-core/app/proxyman/outbound" + appstats "github.com/xtls/xray-core/app/stats" + "github.com/xtls/xray-core/common/serial" + "github.com/xtls/xray-core/core" + feature_outbound "github.com/xtls/xray-core/features/outbound" +) + +func TestMetricsCanRestartInSameProcess(t *testing.T) { + for i := 0; i < 2; i++ { + server := startMetricsTestServer(t) + readMetricsVars(t, server) + readMetricsPprof(t, server) + if err := server.Close(); err != nil { + t.Fatalf("failed to close metrics server: %v", err) + } + } +} + +func TestMetricsCanRunMultipleInstancesInSameProcess(t *testing.T) { + server1 := startMetricsTestServer(t) + t.Cleanup(func() { + _ = server1.Close() + }) + server2 := startMetricsTestServer(t) + t.Cleanup(func() { + _ = server2.Close() + }) + + readMetricsVars(t, server1) + readMetricsVars(t, server2) +} + +func TestMetricsListenOnlyWithoutTagDoesNotRegisterOutbound(t *testing.T) { + listen := pickMetricsListenAddress(t) + server := startMetricsTestServerWithMetricsConfig(t, &Config{ + Listen: listen, + }) + t.Cleanup(func() { + _ = server.Close() + }) + + response, err := http.Get("http://" + listen + "/debug/vars") + if err != nil { + t.Fatalf("failed to read listen-only metrics: %v", err) + } + defer response.Body.Close() + if response.StatusCode != http.StatusOK { + t.Fatalf("unexpected listen-only metrics status: %d", response.StatusCode) + } + + outboundManager := server.GetFeature(feature_outbound.ManagerType()).(feature_outbound.Manager) + if handlers := outboundManager.ListHandlers(context.Background()); len(handlers) != 0 { + t.Fatalf("listen-only metrics registered outbound handlers: got %d, want 0", len(handlers)) + } +} + +func startMetricsTestServer(t *testing.T) *core.Instance { + return startMetricsTestServerWithMetricsConfig(t, &Config{ + Tag: "metrics_out", + }) +} + +func startMetricsTestServerWithMetricsConfig(t *testing.T, metricsConfig *Config) *core.Instance { + t.Helper() + + server, err := core.New(metricsTestConfig(metricsConfig)) + if err != nil { + t.Fatalf("failed to create metrics server: %v", err) + } + if err := server.Start(); err != nil { + _ = server.Close() + t.Fatalf("failed to start metrics server: %v", err) + } + return server +} + +func metricsTestConfig(metricsConfig *Config) *core.Config { + return &core.Config{ + App: []*serial.TypedMessage{ + serial.ToTypedMessage(&dispatcher.Config{}), + serial.ToTypedMessage(&proxyman.InboundConfig{}), + serial.ToTypedMessage(&proxyman.OutboundConfig{}), + serial.ToTypedMessage(&appstats.Config{}), + serial.ToTypedMessage(metricsConfig), + }, + } +} + +func pickMetricsListenAddress(t *testing.T) string { + t.Helper() + + listener, err := stdnet.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("failed to pick metrics listen address: %v", err) + } + defer listener.Close() + return listener.Addr().String() +} + +func readMetricsVars(t *testing.T, server *core.Instance) { + t.Helper() + + recorder := httptest.NewRecorder() + metricsHandler(t, server).httpHandler().ServeHTTP( + recorder, + httptest.NewRequest(http.MethodGet, "/debug/vars", nil), + ) + + if recorder.Code != http.StatusOK { + t.Fatalf("unexpected metrics vars status: %d", recorder.Code) + } + + var payload map[string]interface{} + if err := json.NewDecoder(recorder.Body).Decode(&payload); err != nil { + t.Fatalf("failed to decode metrics vars: %v", err) + } + if _, found := payload["stats"]; !found { + t.Fatal("metrics vars missing stats") + } + if _, found := payload["observatory"]; !found { + t.Fatal("metrics vars missing observatory") + } +} + +func readMetricsPprof(t *testing.T, server *core.Instance) { + t.Helper() + + recorder := httptest.NewRecorder() + metricsHandler(t, server).httpHandler().ServeHTTP( + recorder, + httptest.NewRequest(http.MethodGet, "/debug/pprof/goroutine?debug=1", nil), + ) + + if recorder.Code != http.StatusOK { + t.Fatalf("unexpected metrics pprof status: %d", recorder.Code) + } +} + +func metricsHandler(t *testing.T, server *core.Instance) *MetricsHandler { + t.Helper() + + feature := server.GetFeature((*MetricsHandler)(nil)) + handler, ok := feature.(*MetricsHandler) + if !ok || handler == nil { + t.Fatal("metrics handler not registered") + } + return handler +} diff --git a/proxy/tun/handler.go b/proxy/tun/handler.go index 732468f76..5bf8f231f 100644 --- a/proxy/tun/handler.go +++ b/proxy/tun/handler.go @@ -17,6 +17,7 @@ import ( "github.com/xtls/xray-core/core" "github.com/xtls/xray-core/features/policy" "github.com/xtls/xray-core/features/routing" + "github.com/xtls/xray-core/features/stats" "github.com/xtls/xray-core/transport" "github.com/xtls/xray-core/transport/internet" "github.com/xtls/xray-core/transport/internet/stat" @@ -32,6 +33,8 @@ type Handler struct { dispatcher routing.Dispatcher tag string sniffingRequest session.SniffingRequest + uplinkCounter stats.Counter + downlinkCounter stats.Counter } // ConnectionHandler interface with the only method that stack is going to push new connections to @@ -59,6 +62,23 @@ func (t *Handler) Init(ctx context.Context, pm policy.Manager, dispatcher routin t.policyManager = pm t.dispatcher = dispatcher + if len(t.tag) > 0 && pm.ForSystem().Stats.InboundUplink { + statsManager := core.MustFromContext(ctx).GetFeature(stats.ManagerType()).(stats.Manager) + name := "inbound>>>" + t.tag + ">>>traffic>>>uplink" + c, _ := stats.GetOrRegisterCounter(statsManager, name) + if c != nil { + t.uplinkCounter = c + } + } + if len(t.tag) > 0 && pm.ForSystem().Stats.InboundDownlink { + statsManager := core.MustFromContext(ctx).GetFeature(stats.ManagerType()).(stats.Manager) + name := "inbound>>>" + t.tag + ">>>traffic>>>downlink" + c, _ := stats.GetOrRegisterCounter(statsManager, name) + if c != nil { + t.downlinkCounter = c + } + } + return nil } @@ -151,6 +171,14 @@ func (t *Handler) HandleConnection(conn net.Conn, destination net.Destination) { return } source := net.DestinationFromAddr(remote) + if t.uplinkCounter != nil || t.downlinkCounter != nil { + conn = &stat.CounterConnection{ + Connection: conn, + ReadCounter: t.uplinkCounter, + WriteCounter: t.downlinkCounter, + } + } + inbound := session.Inbound{ Name: "tun", Tag: t.tag, diff --git a/proxy/tun/handler_test.go b/proxy/tun/handler_test.go new file mode 100644 index 000000000..25e9117ab --- /dev/null +++ b/proxy/tun/handler_test.go @@ -0,0 +1,133 @@ +package tun + +import ( + "bytes" + "context" + "net" + "sync/atomic" + "testing" + "time" + + "github.com/xtls/xray-core/common/buf" + xnet "github.com/xtls/xray-core/common/net" + "github.com/xtls/xray-core/features/routing" + "github.com/xtls/xray-core/transport" +) + +type testCounter struct { + value int64 +} + +func (c *testCounter) Value() int64 { + return atomic.LoadInt64(&c.value) +} + +func (c *testCounter) Set(value int64) int64 { + return atomic.SwapInt64(&c.value, value) +} + +func (c *testCounter) Add(value int64) int64 { + return atomic.AddInt64(&c.value, value) - value +} + +type testConn struct { + reader *bytes.Reader + writer bytes.Buffer +} + +func newTestConn(input []byte) *testConn { + return &testConn{reader: bytes.NewReader(input)} +} + +func (c *testConn) Read(payload []byte) (int, error) { + return c.reader.Read(payload) +} + +func (c *testConn) Write(payload []byte) (int, error) { + return c.writer.Write(payload) +} + +func (c *testConn) Close() error { + return nil +} + +func (c *testConn) LocalAddr() net.Addr { + return &net.TCPAddr{IP: net.IPv4(10, 0, 0, 1), Port: 1080} +} + +func (c *testConn) RemoteAddr() net.Addr { + return &net.TCPAddr{IP: net.IPv4(10, 0, 0, 2), Port: 12345} +} + +func (c *testConn) SetDeadline(time.Time) error { + return nil +} + +func (c *testConn) SetReadDeadline(time.Time) error { + return nil +} + +func (c *testConn) SetWriteDeadline(time.Time) error { + return nil +} + +type testDispatcher struct { + writePayload []byte + readBytes int32 +} + +func (d *testDispatcher) Type() interface{} { + return routing.DispatcherType() +} + +func (d *testDispatcher) Start() error { + return nil +} + +func (d *testDispatcher) Close() error { + return nil +} + +func (d *testDispatcher) Dispatch(context.Context, xnet.Destination) (*transport.Link, error) { + return nil, nil +} + +func (d *testDispatcher) DispatchLink(ctx context.Context, dest xnet.Destination, link *transport.Link) error { + mb, err := link.Reader.ReadMultiBuffer() + if err != nil { + return err + } + atomic.StoreInt32(&d.readBytes, mb.Len()) + buf.ReleaseMulti(mb) + + return link.Writer.WriteMultiBuffer(buf.MultiBuffer{buf.FromBytes(d.writePayload)}) +} + +func TestHandlerCountsTunConnectionTraffic(t *testing.T) { + uplinkCounter := new(testCounter) + downlinkCounter := new(testCounter) + dispatcher := &testDispatcher{writePayload: []byte("downlink")} + conn := newTestConn([]byte("uplink")) + + handler := &Handler{ + ctx: context.Background(), + config: &Config{}, + dispatcher: dispatcher, + uplinkCounter: uplinkCounter, + downlinkCounter: downlinkCounter, + } + handler.HandleConnection(conn, xnet.TCPDestination(xnet.LocalHostIP, 443)) + + if got := uplinkCounter.Value(); got != int64(len("uplink")) { + t.Fatalf("unexpected uplink counter: got %d, want %d", got, len("uplink")) + } + if got := downlinkCounter.Value(); got != int64(len("downlink")) { + t.Fatalf("unexpected downlink counter: got %d, want %d", got, len("downlink")) + } + if got := int(atomic.LoadInt32(&dispatcher.readBytes)); got != len("uplink") { + t.Fatalf("dispatcher read unexpected bytes: got %d, want %d", got, len("uplink")) + } + if got := conn.writer.String(); got != "downlink" { + t.Fatalf("connection write mismatch: got %q, want %q", got, "downlink") + } +}