diff --git a/infra/conf/transport_browser_dialer_test.go b/infra/conf/transport_browser_dialer_test.go deleted file mode 100644 index c5e3a383..00000000 --- a/infra/conf/transport_browser_dialer_test.go +++ /dev/null @@ -1,63 +0,0 @@ -package conf_test - -import ( - "net" - "strings" - "testing" - - . "github.com/xtls/xray-core/infra/conf" -) - -const testBrowserDialerPath = "/123e4567-e89b-12d3-a456-426614174000" - -func TestStreamConfigBuildRejectsBrowserDialerUnsupportedProtocol(t *testing.T) { - network := TransportProtocol("tcp") - config := &StreamConfig{ - Network: &network, - SocketSettings: &SocketConfig{ - BrowserDialer: "127.0.0.1:18080" + testBrowserDialerPath, - }, - } - - _, err := config.Build() - if err == nil || !strings.Contains(err.Error(), "sockopt.browserDialer only supports websocket or splithttp") { - t.Fatalf("expected unsupported protocol error, got: %v", err) - } -} - -func TestStreamConfigBuildRejectsBrowserDialerWithREALITY(t *testing.T) { - network := TransportProtocol("splithttp") - config := &StreamConfig{ - Network: &network, - Security: "reality", - SocketSettings: &SocketConfig{ - BrowserDialer: "127.0.0.1:18081" + testBrowserDialerPath, - }, - } - - _, err := config.Build() - if err == nil || !strings.Contains(err.Error(), "sockopt.browserDialer does not support REALITY") { - t.Fatalf("expected REALITY rejection, got: %v", err) - } -} - -func TestStreamConfigBuildFailsOnBrowserDialerAddressConflict(t *testing.T) { - listener, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("failed to prepare occupied listener: %v", err) - } - defer listener.Close() - - network := TransportProtocol("websocket") - config := &StreamConfig{ - Network: &network, - SocketSettings: &SocketConfig{ - BrowserDialer: listener.Addr().String() + testBrowserDialerPath, - }, - } - - _, err = config.Build() - if err == nil || !strings.Contains(err.Error(), "Failed to start Browser Dialer listener") { - t.Fatalf("expected address conflict error, got: %v", err) - } -} diff --git a/transport/internet/browser_dialer/dialer.go b/transport/internet/browser_dialer/dialer.go index 7d855114..d39281f9 100644 --- a/transport/internet/browser_dialer/dialer.go +++ b/transport/internet/browser_dialer/dialer.go @@ -47,7 +47,7 @@ var upgrader = &websocket.Upgrader{ } func HasBrowserDialerWithAddress(addr string) bool { - _, ok := parseBrowserDialerAddress(addr) + _, _, ok := parseBrowserDialerAddress(addr) return ok } @@ -56,79 +56,56 @@ type webSocketExtra struct { } type dialerInstance struct { - conns chan *websocket.Conn - pagePath string - page []byte + conns chan *websocket.Conn + page []byte } type dialerServer struct { server *http.Server pageRoutes map[string]*dialerInstance - started bool } -type browserDialerAddress struct { - listenAddr string - path string -} - -func parseBrowserDialerAddress(addr string) (*browserDialerAddress, bool) { +func parseBrowserDialerAddress(addr string) (string, string, bool) { if addr == "" { - return nil, false + return "", "", false } - index := strings.Index(addr, "/") - if index <= 0 { - return nil, false + listenAddr, pathRaw, ok := strings.Cut(addr, "/") + if !ok || listenAddr == "" { + return "", "", false } - listenAddr := addr[:index] - path := strings.TrimSuffix(addr[index:], "/") + path := "/" + strings.TrimSuffix(pathRaw, "/") if path == "" { - return nil, false + return "", "", false } if _, _, err := net.SplitHostPort(listenAddr); err != nil { - return nil, false + return "", "", false } parsedPath, err := url.ParseRequestURI(path) if err != nil || parsedPath.RawQuery != "" || parsedPath.Fragment != "" { - return nil, false + return "", "", false } cleanPath := pathlib.Clean(path) if cleanPath == "." || cleanPath == "/" || cleanPath != path { - return nil, false + return "", "", false } if strings.Count(cleanPath, "/") != 1 { - return nil, false + return "", "", false } id := strings.TrimPrefix(cleanPath, "/") if len(id) != 36 { - return nil, false + return "", "", false } id = strings.ToLower(id) parsedUUID, err := uuid.ParseString(id) if err != nil || parsedUUID.String() != id { - return nil, false + return "", "", false } - cleanPath = "/" + id - - return &browserDialerAddress{ - listenAddr: listenAddr, - path: cleanPath, - }, true + return listenAddr, "/" + id, true } -func newDialerInstance(path string) *dialerInstance { - page := bytes.ReplaceAll(webpage, []byte("dialerPath"), []byte(strings.TrimPrefix(path, "/"))) - dialer := &dialerInstance{ - conns: make(chan *websocket.Conn, 256), - pagePath: path, - page: page, - } - return dialer -} - -func newDialerServer(listenAddr string) *dialerServer { +func newDialerServer(listenAddr string) (*dialerServer, error) { dialer := &dialerServer{ pageRoutes: make(map[string]*dialerInstance), } @@ -170,7 +147,16 @@ func newDialerServer(listenAddr string) *dialerServer { closeConnection(w) }), } - return dialer + listener, err := net.Listen("tcp", listenAddr) + if err != nil { + return nil, err + } + go func() { + if err := dialer.server.Serve(listener); err != nil && !stderrors.Is(err, http.ErrServerClosed) { + errors.LogError(context.Background(), "Browser dialer http server unexpected error on ", dialer.server.Addr, ": ", err) + } + }() + return dialer, nil } func closeConnection(w http.ResponseWriter) { @@ -185,47 +171,17 @@ func closeConnection(w http.ResponseWriter) { conn.Close() } -func startDialerServer(dialer *dialerServer) error { - if dialer == nil || dialer.server == nil { - return nil - } - listener, err := net.Listen("tcp", dialer.server.Addr) - if err != nil { - return err - } - go func() { - if err := dialer.server.Serve(listener); err != nil && !stderrors.Is(err, http.ErrServerClosed) { - errors.LogError(context.Background(), "Browser dialer http server unexpected error on ", dialer.server.Addr, ": ", err) - } - }() - return nil -} - -func closeDialerInstance(d *dialerInstance) { - if d == nil { - return - } - for { - select { - case c := <-d.conns: - c.Close() - default: - return - } - } -} - func getDialerByAddress(addr string) (*dialerInstance, error) { - parsed, ok := parseBrowserDialerAddress(addr) + listenAddr, path, ok := parseBrowserDialerAddress(addr) if !ok { return nil, errors.New("invalid sockopt.browserDialer: ", addr) } - key := parsed.listenAddr + parsed.path - var server *dialerServer - var dialer *dialerInstance + key := listenAddr + path mu.Lock() + defer mu.Unlock() + if sockoptDialers == nil { sockoptDialers = make(map[string]*dialerInstance) } @@ -233,38 +189,24 @@ func getDialerByAddress(addr string) (*dialerInstance, error) { dialerServers = make(map[string]*dialerServer) } if dialer, found := sockoptDialers[key]; found { - mu.Unlock() return dialer, nil } - found := false - server, found = dialerServers[parsed.listenAddr] + server, found := dialerServers[listenAddr] if !found { - server = newDialerServer(parsed.listenAddr) - dialerServers[parsed.listenAddr] = server - } - - dialer = newDialerInstance(parsed.path) - sockoptDialers[key] = dialer - server.pageRoutes[dialer.pagePath] = dialer - startServer := !server.started - server.started = true - mu.Unlock() - - if startServer { - if err := startDialerServer(server); err != nil { - mu.Lock() - delete(sockoptDialers, key) - delete(server.pageRoutes, dialer.pagePath) - if len(server.pageRoutes) == 0 { - delete(dialerServers, parsed.listenAddr) - } - mu.Unlock() - closeDialerInstance(dialer) + server, err := newDialerServer(listenAddr) + if err != nil { return nil, err } + dialerServers[listenAddr] = server } + dialer := &dialerInstance{ + conns: make(chan *websocket.Conn, 256), + page: bytes.ReplaceAll(webpage, []byte("dialerPath"), []byte(strings.TrimPrefix(path, "/"))), + } + sockoptDialers[key] = dialer + server.pageRoutes[path] = dialer return dialer, nil } @@ -276,10 +218,6 @@ func EnsureDialerWithAddress(addr string) error { return err } -func DialWS(uri string, ed []byte) (*websocket.Conn, error) { - return DialWSWithAddress("", uri, ed) -} - func DialWSWithAddress(addr string, uri string, ed []byte) (*websocket.Conn, error) { task := task{ Method: "WS", @@ -330,10 +268,6 @@ func httpExtraFromHeadersAndCookies(headers http.Header, cookies []*http.Cookie) return &extra } -func DialGet(uri string, headers http.Header, cookies []*http.Cookie) (*websocket.Conn, error) { - return DialGetWithAddress("", uri, headers, cookies) -} - func DialGetWithAddress(addr string, uri string, headers http.Header, cookies []*http.Cookie) (*websocket.Conn, error) { task := task{ Method: "GET", @@ -345,15 +279,7 @@ func DialGetWithAddress(addr string, uri string, headers http.Header, cookies [] return dialTaskWithAddress(addr, task) } -func DialPacket(method string, uri string, headers http.Header, cookies []*http.Cookie, payload []byte) error { - return DialPacketWithAddress("", method, uri, headers, cookies, payload) -} - func DialPacketWithAddress(addr string, method string, uri string, headers http.Header, cookies []*http.Cookie, payload []byte) error { - return dialWithBody(addr, method, uri, headers, cookies, payload) -} - -func dialWithBody(addr string, method string, uri string, headers http.Header, cookies []*http.Cookie, payload []byte) error { task := task{ Method: method, URL: uri, @@ -380,23 +306,23 @@ func dialWithBody(addr string, method string, uri string, headers http.Header, c return nil } -func dialTask(task task) (*websocket.Conn, error) { - return dialTaskWithAddress("", task) -} - func dialTaskWithAddress(addr string, task task) (*websocket.Conn, error) { data, err := json.Marshal(task) if err != nil { return nil, err } - conns := connsByAddress(addr) - if conns == nil { + if addr == "" { + return nil, errors.New("browser dialer is not configured; set sockopt.browserDialer") + } + dialer, err := getDialerByAddress(addr) + if err != nil || dialer == nil { if addr != "" { return nil, errors.New("browser dialer is not configured for sockopt.browserDialer: ", addr) } return nil, errors.New("browser dialer is not configured; set sockopt.browserDialer") } + conns := dialer.conns var conn *websocket.Conn for { @@ -427,17 +353,6 @@ func CheckOK(conn *websocket.Conn) error { return nil } -func connsByAddress(addr string) chan *websocket.Conn { - if addr == "" { - return nil - } - dialer, err := getDialerByAddress(addr) - if err != nil || dialer == nil { - return nil - } - return dialer.conns -} - func notifyRemovedEnv() { envAddress := platform.NewEnvFlag(platform.BrowserDialerAddress).GetValue(func() string { return "" }) if envAddress == "" { diff --git a/transport/internet/browser_dialer/dialer_test.go b/transport/internet/browser_dialer/dialer_test.go index 506ee2a2..9384306f 100644 --- a/transport/internet/browser_dialer/dialer_test.go +++ b/transport/internet/browser_dialer/dialer_test.go @@ -4,7 +4,7 @@ import "testing" func TestParseBrowserDialerAddressRequireUUIDPath(t *testing.T) { valid := "127.0.0.1:8080/123e4567-e89b-12d3-a456-426614174000" - if _, ok := parseBrowserDialerAddress(valid); !ok { + if _, _, ok := parseBrowserDialerAddress(valid); !ok { t.Fatalf("expected valid browser dialer address: %s", valid) } @@ -15,7 +15,7 @@ func TestParseBrowserDialerAddressRequireUUIDPath(t *testing.T) { "127.0.0.1:8080/123e4567-e89b-12d3-a456-426614174000/extra", } for _, addr := range invalid { - if _, ok := parseBrowserDialerAddress(addr); ok { + if _, _, ok := parseBrowserDialerAddress(addr); ok { t.Fatalf("expected invalid browser dialer address: %s", addr) } }