diff --git a/transport/internet/browser_dialer/dialer.go b/transport/internet/browser_dialer/dialer.go index f094c94d..1784855f 100644 --- a/transport/internet/browser_dialer/dialer.go +++ b/transport/internet/browser_dialer/dialer.go @@ -18,7 +18,6 @@ import ( "github.com/gorilla/websocket" "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/platform" - "github.com/xtls/xray-core/common/uuid" ) //go:embed dialer.html @@ -35,6 +34,8 @@ var sockoptDialers map[string]*dialerInstance var dialerServers map[string]*dialerServer var mu sync.RWMutex +const browserDialerSubprotocol = "browser-dialer" + var upgrader = &websocket.Upgrader{ ReadBufferSize: 0, WriteBufferSize: 0, @@ -56,14 +57,12 @@ type webSocketExtra struct { type dialerInstance struct { conns chan *websocket.Conn pagePath string - wsPath string page []byte } type dialerServer struct { server *http.Server pageRoutes map[string]*dialerInstance - wsRoutes map[string]*dialerInstance } type browserDialerAddress struct { @@ -105,16 +104,10 @@ func parseBrowserDialerAddress(addr string) (*browserDialerAddress, bool) { } func newDialerInstance(path string) *dialerInstance { - token := uuid.New() - csrfToken := token.String() - escapedCsrfToken := url.PathEscape(csrfToken) - wsPath := path + "/" + escapedCsrfToken page := bytes.ReplaceAll(webpage, []byte("dialerPath"), []byte(strings.TrimPrefix(path, "/"))) - page = bytes.ReplaceAll(page, []byte("csrfToken"), []byte(escapedCsrfToken)) dialer := &dialerInstance{ conns: make(chan *websocket.Conn, 256), pagePath: path, - wsPath: wsPath, page: page, } return dialer @@ -123,19 +116,28 @@ func newDialerInstance(path string) *dialerInstance { func newDialerServer(listenAddr string) *dialerServer { dialer := &dialerServer{ pageRoutes: make(map[string]*dialerInstance), - wsRoutes: make(map[string]*dialerInstance), } dialer.server = &http.Server{ Addr: listenAddr, Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mu.RLock() - wsDialer := dialer.wsRoutes[r.URL.Path] pageDialer := dialer.pageRoutes[r.URL.Path] mu.RUnlock() - if wsDialer != nil { - if conn, err := upgrader.Upgrade(w, r, nil); err == nil { - wsDialer.conns <- conn + if pageDialer != nil && websocket.IsWebSocketUpgrade(r) { + ok := false + for _, protocol := range websocket.Subprotocols(r) { + if protocol == browserDialerSubprotocol { + ok = true + break + } + } + if !ok { + closeConnection(w) + return + } + if conn, err := upgrader.Upgrade(w, r, http.Header{"Sec-WebSocket-Protocol": []string{browserDialerSubprotocol}}); err == nil { + pageDialer.conns <- conn } else { errors.LogError(context.Background(), "Browser dialer http upgrade unexpected error: ", err) } @@ -224,7 +226,6 @@ func getDialerByAddress(addr string) *dialerInstance { dialer := newDialerInstance(parsed.path) sockoptDialers[key] = dialer server.pageRoutes[dialer.pagePath] = dialer - server.wsRoutes[dialer.wsPath] = dialer mu.Unlock() if startServer { diff --git a/transport/internet/browser_dialer/dialer.html b/transport/internet/browser_dialer/dialer.html index 379dc52b..acef530b 100644 --- a/transport/internet/browser_dialer/dialer.html +++ b/transport/internet/browser_dialer/dialer.html @@ -10,7 +10,7 @@ // Enable a much more aggressive JIT for performance gains // Copyright (c) 2021 XRAY. Mozilla Public License 2.0. - let url = "ws://" + window.location.host + "/dialerPath/csrfToken"; + let url = "ws://" + window.location.host + "/dialerPath"; let clientIdleCount = 0; let upstreamGetCount = 0; let upstreamWsCount = 0; @@ -67,7 +67,7 @@ } clientIdleCount += 1; console.log("Prepare", url); - let ws = new WebSocket(url); + let ws = new WebSocket(url, "browser-dialer"); // arraybuffer is significantly faster in chrome than default // blob, tested with chrome 123 ws.binaryType = "arraybuffer";