refactor: simplify browser dialer and remove added conf tests

Agent-Logs-Url: https://github.com/XTLS/Xray-core/sessions/3aee4c73-7847-433c-905a-2eafe5b1bfe8

Co-authored-by: RPRX <63339210+RPRX@users.noreply.github.com>
This commit is contained in:
copilot-swe-agent[bot]
2026-04-26 16:23:37 +00:00
committed by GitHub
parent 1d4250e6f0
commit 5afd664c8b
3 changed files with 50 additions and 198 deletions
@@ -1,63 +0,0 @@
package conf_test
import (
"net"
"strings"
"testing"
. "github.com/xtls/xray-core/infra/conf"
)
const testBrowserDialerPath = "/123e4567-e89b-12d3-a456-426614174000"
func TestStreamConfigBuildRejectsBrowserDialerUnsupportedProtocol(t *testing.T) {
network := TransportProtocol("tcp")
config := &StreamConfig{
Network: &network,
SocketSettings: &SocketConfig{
BrowserDialer: "127.0.0.1:18080" + testBrowserDialerPath,
},
}
_, err := config.Build()
if err == nil || !strings.Contains(err.Error(), "sockopt.browserDialer only supports websocket or splithttp") {
t.Fatalf("expected unsupported protocol error, got: %v", err)
}
}
func TestStreamConfigBuildRejectsBrowserDialerWithREALITY(t *testing.T) {
network := TransportProtocol("splithttp")
config := &StreamConfig{
Network: &network,
Security: "reality",
SocketSettings: &SocketConfig{
BrowserDialer: "127.0.0.1:18081" + testBrowserDialerPath,
},
}
_, err := config.Build()
if err == nil || !strings.Contains(err.Error(), "sockopt.browserDialer does not support REALITY") {
t.Fatalf("expected REALITY rejection, got: %v", err)
}
}
func TestStreamConfigBuildFailsOnBrowserDialerAddressConflict(t *testing.T) {
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("failed to prepare occupied listener: %v", err)
}
defer listener.Close()
network := TransportProtocol("websocket")
config := &StreamConfig{
Network: &network,
SocketSettings: &SocketConfig{
BrowserDialer: listener.Addr().String() + testBrowserDialerPath,
},
}
_, err = config.Build()
if err == nil || !strings.Contains(err.Error(), "Failed to start Browser Dialer listener") {
t.Fatalf("expected address conflict error, got: %v", err)
}
}
+48 -133
View File
@@ -47,7 +47,7 @@ var upgrader = &websocket.Upgrader{
} }
func HasBrowserDialerWithAddress(addr string) bool { func HasBrowserDialerWithAddress(addr string) bool {
_, ok := parseBrowserDialerAddress(addr) _, _, ok := parseBrowserDialerAddress(addr)
return ok return ok
} }
@@ -56,79 +56,56 @@ type webSocketExtra struct {
} }
type dialerInstance struct { type dialerInstance struct {
conns chan *websocket.Conn conns chan *websocket.Conn
pagePath string page []byte
page []byte
} }
type dialerServer struct { type dialerServer struct {
server *http.Server server *http.Server
pageRoutes map[string]*dialerInstance pageRoutes map[string]*dialerInstance
started bool
} }
type browserDialerAddress struct { func parseBrowserDialerAddress(addr string) (string, string, bool) {
listenAddr string
path string
}
func parseBrowserDialerAddress(addr string) (*browserDialerAddress, bool) {
if addr == "" { if addr == "" {
return nil, false return "", "", false
} }
index := strings.Index(addr, "/") listenAddr, pathRaw, ok := strings.Cut(addr, "/")
if index <= 0 { if !ok || listenAddr == "" {
return nil, false return "", "", false
} }
listenAddr := addr[:index] path := "/" + strings.TrimSuffix(pathRaw, "/")
path := strings.TrimSuffix(addr[index:], "/")
if path == "" { if path == "" {
return nil, false return "", "", false
} }
if _, _, err := net.SplitHostPort(listenAddr); err != nil { if _, _, err := net.SplitHostPort(listenAddr); err != nil {
return nil, false return "", "", false
} }
parsedPath, err := url.ParseRequestURI(path) parsedPath, err := url.ParseRequestURI(path)
if err != nil || parsedPath.RawQuery != "" || parsedPath.Fragment != "" { if err != nil || parsedPath.RawQuery != "" || parsedPath.Fragment != "" {
return nil, false return "", "", false
} }
cleanPath := pathlib.Clean(path) cleanPath := pathlib.Clean(path)
if cleanPath == "." || cleanPath == "/" || cleanPath != path { if cleanPath == "." || cleanPath == "/" || cleanPath != path {
return nil, false return "", "", false
} }
if strings.Count(cleanPath, "/") != 1 { if strings.Count(cleanPath, "/") != 1 {
return nil, false return "", "", false
} }
id := strings.TrimPrefix(cleanPath, "/") id := strings.TrimPrefix(cleanPath, "/")
if len(id) != 36 { if len(id) != 36 {
return nil, false return "", "", false
} }
id = strings.ToLower(id) id = strings.ToLower(id)
parsedUUID, err := uuid.ParseString(id) parsedUUID, err := uuid.ParseString(id)
if err != nil || parsedUUID.String() != id { if err != nil || parsedUUID.String() != id {
return nil, false return "", "", false
} }
cleanPath = "/" + id return listenAddr, "/" + id, true
return &browserDialerAddress{
listenAddr: listenAddr,
path: cleanPath,
}, true
} }
func newDialerInstance(path string) *dialerInstance { func newDialerServer(listenAddr string) (*dialerServer, error) {
page := bytes.ReplaceAll(webpage, []byte("dialerPath"), []byte(strings.TrimPrefix(path, "/")))
dialer := &dialerInstance{
conns: make(chan *websocket.Conn, 256),
pagePath: path,
page: page,
}
return dialer
}
func newDialerServer(listenAddr string) *dialerServer {
dialer := &dialerServer{ dialer := &dialerServer{
pageRoutes: make(map[string]*dialerInstance), pageRoutes: make(map[string]*dialerInstance),
} }
@@ -170,7 +147,16 @@ func newDialerServer(listenAddr string) *dialerServer {
closeConnection(w) closeConnection(w)
}), }),
} }
return dialer listener, err := net.Listen("tcp", listenAddr)
if err != nil {
return nil, err
}
go func() {
if err := dialer.server.Serve(listener); err != nil && !stderrors.Is(err, http.ErrServerClosed) {
errors.LogError(context.Background(), "Browser dialer http server unexpected error on ", dialer.server.Addr, ": ", err)
}
}()
return dialer, nil
} }
func closeConnection(w http.ResponseWriter) { func closeConnection(w http.ResponseWriter) {
@@ -185,47 +171,17 @@ func closeConnection(w http.ResponseWriter) {
conn.Close() conn.Close()
} }
func startDialerServer(dialer *dialerServer) error {
if dialer == nil || dialer.server == nil {
return nil
}
listener, err := net.Listen("tcp", dialer.server.Addr)
if err != nil {
return err
}
go func() {
if err := dialer.server.Serve(listener); err != nil && !stderrors.Is(err, http.ErrServerClosed) {
errors.LogError(context.Background(), "Browser dialer http server unexpected error on ", dialer.server.Addr, ": ", err)
}
}()
return nil
}
func closeDialerInstance(d *dialerInstance) {
if d == nil {
return
}
for {
select {
case c := <-d.conns:
c.Close()
default:
return
}
}
}
func getDialerByAddress(addr string) (*dialerInstance, error) { func getDialerByAddress(addr string) (*dialerInstance, error) {
parsed, ok := parseBrowserDialerAddress(addr) listenAddr, path, ok := parseBrowserDialerAddress(addr)
if !ok { if !ok {
return nil, errors.New("invalid sockopt.browserDialer: ", addr) return nil, errors.New("invalid sockopt.browserDialer: ", addr)
} }
key := parsed.listenAddr + parsed.path key := listenAddr + path
var server *dialerServer
var dialer *dialerInstance
mu.Lock() mu.Lock()
defer mu.Unlock()
if sockoptDialers == nil { if sockoptDialers == nil {
sockoptDialers = make(map[string]*dialerInstance) sockoptDialers = make(map[string]*dialerInstance)
} }
@@ -233,38 +189,24 @@ func getDialerByAddress(addr string) (*dialerInstance, error) {
dialerServers = make(map[string]*dialerServer) dialerServers = make(map[string]*dialerServer)
} }
if dialer, found := sockoptDialers[key]; found { if dialer, found := sockoptDialers[key]; found {
mu.Unlock()
return dialer, nil return dialer, nil
} }
found := false server, found := dialerServers[listenAddr]
server, found = dialerServers[parsed.listenAddr]
if !found { if !found {
server = newDialerServer(parsed.listenAddr) server, err := newDialerServer(listenAddr)
dialerServers[parsed.listenAddr] = server if err != nil {
}
dialer = newDialerInstance(parsed.path)
sockoptDialers[key] = dialer
server.pageRoutes[dialer.pagePath] = dialer
startServer := !server.started
server.started = true
mu.Unlock()
if startServer {
if err := startDialerServer(server); err != nil {
mu.Lock()
delete(sockoptDialers, key)
delete(server.pageRoutes, dialer.pagePath)
if len(server.pageRoutes) == 0 {
delete(dialerServers, parsed.listenAddr)
}
mu.Unlock()
closeDialerInstance(dialer)
return nil, err return nil, err
} }
dialerServers[listenAddr] = server
} }
dialer := &dialerInstance{
conns: make(chan *websocket.Conn, 256),
page: bytes.ReplaceAll(webpage, []byte("dialerPath"), []byte(strings.TrimPrefix(path, "/"))),
}
sockoptDialers[key] = dialer
server.pageRoutes[path] = dialer
return dialer, nil return dialer, nil
} }
@@ -276,10 +218,6 @@ func EnsureDialerWithAddress(addr string) error {
return err return err
} }
func DialWS(uri string, ed []byte) (*websocket.Conn, error) {
return DialWSWithAddress("", uri, ed)
}
func DialWSWithAddress(addr string, uri string, ed []byte) (*websocket.Conn, error) { func DialWSWithAddress(addr string, uri string, ed []byte) (*websocket.Conn, error) {
task := task{ task := task{
Method: "WS", Method: "WS",
@@ -330,10 +268,6 @@ func httpExtraFromHeadersAndCookies(headers http.Header, cookies []*http.Cookie)
return &extra return &extra
} }
func DialGet(uri string, headers http.Header, cookies []*http.Cookie) (*websocket.Conn, error) {
return DialGetWithAddress("", uri, headers, cookies)
}
func DialGetWithAddress(addr string, uri string, headers http.Header, cookies []*http.Cookie) (*websocket.Conn, error) { func DialGetWithAddress(addr string, uri string, headers http.Header, cookies []*http.Cookie) (*websocket.Conn, error) {
task := task{ task := task{
Method: "GET", Method: "GET",
@@ -345,15 +279,7 @@ func DialGetWithAddress(addr string, uri string, headers http.Header, cookies []
return dialTaskWithAddress(addr, task) return dialTaskWithAddress(addr, task)
} }
func DialPacket(method string, uri string, headers http.Header, cookies []*http.Cookie, payload []byte) error {
return DialPacketWithAddress("", method, uri, headers, cookies, payload)
}
func DialPacketWithAddress(addr string, method string, uri string, headers http.Header, cookies []*http.Cookie, payload []byte) error { func DialPacketWithAddress(addr string, method string, uri string, headers http.Header, cookies []*http.Cookie, payload []byte) error {
return dialWithBody(addr, method, uri, headers, cookies, payload)
}
func dialWithBody(addr string, method string, uri string, headers http.Header, cookies []*http.Cookie, payload []byte) error {
task := task{ task := task{
Method: method, Method: method,
URL: uri, URL: uri,
@@ -380,23 +306,23 @@ func dialWithBody(addr string, method string, uri string, headers http.Header, c
return nil return nil
} }
func dialTask(task task) (*websocket.Conn, error) {
return dialTaskWithAddress("", task)
}
func dialTaskWithAddress(addr string, task task) (*websocket.Conn, error) { func dialTaskWithAddress(addr string, task task) (*websocket.Conn, error) {
data, err := json.Marshal(task) data, err := json.Marshal(task)
if err != nil { if err != nil {
return nil, err return nil, err
} }
conns := connsByAddress(addr) if addr == "" {
if conns == nil { return nil, errors.New("browser dialer is not configured; set sockopt.browserDialer")
}
dialer, err := getDialerByAddress(addr)
if err != nil || dialer == nil {
if addr != "" { if addr != "" {
return nil, errors.New("browser dialer is not configured for sockopt.browserDialer: ", addr) return nil, errors.New("browser dialer is not configured for sockopt.browserDialer: ", addr)
} }
return nil, errors.New("browser dialer is not configured; set sockopt.browserDialer") return nil, errors.New("browser dialer is not configured; set sockopt.browserDialer")
} }
conns := dialer.conns
var conn *websocket.Conn var conn *websocket.Conn
for { for {
@@ -427,17 +353,6 @@ func CheckOK(conn *websocket.Conn) error {
return nil return nil
} }
func connsByAddress(addr string) chan *websocket.Conn {
if addr == "" {
return nil
}
dialer, err := getDialerByAddress(addr)
if err != nil || dialer == nil {
return nil
}
return dialer.conns
}
func notifyRemovedEnv() { func notifyRemovedEnv() {
envAddress := platform.NewEnvFlag(platform.BrowserDialerAddress).GetValue(func() string { return "" }) envAddress := platform.NewEnvFlag(platform.BrowserDialerAddress).GetValue(func() string { return "" })
if envAddress == "" { if envAddress == "" {
@@ -4,7 +4,7 @@ import "testing"
func TestParseBrowserDialerAddressRequireUUIDPath(t *testing.T) { func TestParseBrowserDialerAddressRequireUUIDPath(t *testing.T) {
valid := "127.0.0.1:8080/123e4567-e89b-12d3-a456-426614174000" valid := "127.0.0.1:8080/123e4567-e89b-12d3-a456-426614174000"
if _, ok := parseBrowserDialerAddress(valid); !ok { if _, _, ok := parseBrowserDialerAddress(valid); !ok {
t.Fatalf("expected valid browser dialer address: %s", valid) t.Fatalf("expected valid browser dialer address: %s", valid)
} }
@@ -15,7 +15,7 @@ func TestParseBrowserDialerAddressRequireUUIDPath(t *testing.T) {
"127.0.0.1:8080/123e4567-e89b-12d3-a456-426614174000/extra", "127.0.0.1:8080/123e4567-e89b-12d3-a456-426614174000/extra",
} }
for _, addr := range invalid { for _, addr := range invalid {
if _, ok := parseBrowserDialerAddress(addr); ok { if _, _, ok := parseBrowserDialerAddress(addr); ok {
t.Fatalf("expected invalid browser dialer address: %s", addr) t.Fatalf("expected invalid browser dialer address: %s", addr)
} }
} }