diff --git a/transport/internet/browser_dialer/dialer.go b/transport/internet/browser_dialer/dialer.go index be2f137e..53955bc4 100644 --- a/transport/internet/browser_dialer/dialer.go +++ b/transport/internet/browser_dialer/dialer.go @@ -7,6 +7,7 @@ import ( "encoding/base64" "encoding/json" "net/http" + "sync" "time" "github.com/gorilla/websocket" @@ -26,6 +27,8 @@ type task struct { } var conns chan *websocket.Conn +var server *http.Server +var mu sync.Mutex var upgrader = &websocket.Upgrader{ ReadBufferSize: 0, @@ -36,27 +39,48 @@ var upgrader = &websocket.Upgrader{ }, } -func init() { +// Used by external projects when using xray as a go module +func Reload() { addr := platform.NewEnvFlag(platform.BrowserDialerAddress).GetValue(func() string { return "" }) + mu.Lock() + defer mu.Unlock() + + if server != nil { + server.Close() + } + if HasBrowserDialer() { + for len(conns) > 0 { + select { + case c := <-conns: + c.Close() + default: + } + } + conns = nil + } if addr != "" { token := uuid.New() csrfToken := token.String() - webpage = bytes.ReplaceAll(webpage, []byte("csrfToken"), []byte(csrfToken)) + webpage := bytes.ReplaceAll(webpage, []byte("csrfToken"), []byte(csrfToken)) conns = make(chan *websocket.Conn, 256) - go http.ListenAndServe(addr, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path == "/websocket" { - if r.URL.Query().Get("token") == csrfToken { - if conn, err := upgrader.Upgrade(w, r, nil); err == nil { - conns <- conn - } else { - errors.LogError(context.Background(), "Browser dialer http upgrade unexpected error") + server = &http.Server{ + Addr: addr, + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/websocket" { + if r.URL.Query().Get("token") == csrfToken { + if conn, err := upgrader.Upgrade(w, r, nil); err == nil { + conns <- conn + } else { + errors.LogError(context.Background(), "Browser dialer http upgrade unexpected error") + } } + } else { + w.Header().Set("Access-Control-Allow-Origin", "*"); + w.Write(webpage) } - } else { - w.Header().Set("Access-Control-Allow-Origin", "*"); - w.Write(webpage) - } - })) + }), + } + go server.ListenAndServe() } } @@ -194,3 +218,8 @@ func CheckOK(conn *websocket.Conn) error { return nil } + +func init() { + Reload() +} +