diff --git a/app/geodata/download.go b/app/geodata/download.go index cc1498a0f..e24e3e4fd 100644 --- a/app/geodata/download.go +++ b/app/geodata/download.go @@ -2,6 +2,7 @@ package geodata import ( "context" + "crypto/tls" go_errors "errors" "io" "net/http" @@ -9,6 +10,7 @@ import ( "path/filepath" "time" + utls "github.com/refraction-networking/utls" "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/platform/filesystem" @@ -16,6 +18,7 @@ import ( "github.com/xtls/xray-core/common/utils" "github.com/xtls/xray-core/features/routing" "github.com/xtls/xray-core/transport/internet/tagged" + "golang.org/x/net/http2" ) const idleTimeout = 30 * time.Second @@ -26,8 +29,9 @@ type stage struct { } type downloader struct { - ctx context.Context - client *http.Client + ctx context.Context + httpClient *http.Client + httpsClient *http.Client } type idleConn struct { @@ -53,52 +57,84 @@ func (c *idleConn) Write(b []byte) (int, error) { func newDownloader(ctx context.Context, dispatcher routing.Dispatcher, outbound string) *downloader { return &downloader{ - ctx: ctx, - client: newClient(ctx, dispatcher, outbound), + ctx: ctx, + httpClient: newClient(ctx, dispatcher, outbound, false), + httpsClient: newClient(ctx, dispatcher, outbound, true), } } -func newClient(baseCtx context.Context, dispatcher routing.Dispatcher, outbound string) *http.Client { - return &http.Client{ - Transport: &http.Transport{ - Proxy: nil, - DisableKeepAlives: true, - DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { - var conn net.Conn - err := task.Run(ctx, func() error { - if tagged.Dialer == nil { - return errors.New("tagged dialer is not initialized") - } - dest, err := net.ParseDestination(network + ":" + address) - if err != nil { - return errors.New("cannot understand address").Base(err) - } - c, err := tagged.Dialer(baseCtx, dispatcher, dest, outbound) - if err != nil { - return errors.New("cannot dial remote address ", dest).Base(err) - } - conn = c - return nil - }) - if err != nil { - return nil, errors.New("cannot finish connection").Base(err) - } - return &idleConn{ - Conn: conn, - }, nil - }, - TLSHandshakeTimeout: idleTimeout, - ResponseHeaderTimeout: idleTimeout, - }, - CheckRedirect: func(req *http.Request, via []*http.Request) error { - if req.URL.Scheme != "https" { - return errors.New("redirected to non-https URL: ", req.URL.String()) +func newClient(baseCtx context.Context, dispatcher routing.Dispatcher, outbound string, isHTTPS bool) *http.Client { + dial := func(ctx context.Context, network, address string) (net.Conn, error) { + var conn net.Conn + err := task.Run(ctx, func() error { + if tagged.Dialer == nil { + return errors.New("tagged dialer is not initialized") } - if len(via) >= 10 { - return errors.New("stopped after 10 redirects") + dest, err := net.ParseDestination(network + ":" + address) + if err != nil { + return errors.New("cannot understand address").Base(err) } + c, err := tagged.Dialer(baseCtx, dispatcher, dest, outbound) + if err != nil { + return errors.New("cannot dial remote address ", dest).Base(err) + } + conn = c return nil - }, + }) + if err != nil { + return nil, errors.New("cannot finish connection").Base(err) + } + return &idleConn{ + Conn: conn, + }, nil + } + if isHTTPS { + return &http.Client{ + Transport: &http2.Transport{ + DialTLSContext: func(ctx context.Context, network string, address string, cfg *tls.Config) (net.Conn, error) { + conn, err := dial(ctx, network, address) + if err != nil { + return nil, err + } + host, _, _ := net.SplitHostPort(address) + tlsConn := utls.UClient(conn, &utls.Config{ServerName: host}, utls.HelloChrome_Auto) + handshakeCtx, cancel := context.WithTimeout(ctx, idleTimeout) + defer cancel() + if err := tlsConn.HandshakeContext(handshakeCtx); err != nil { + conn.Close() + return nil, err + } + return tlsConn, nil + }, + }, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + if req.URL.Scheme != "https" { + return errors.New("redirected to non-https URL: ", req.URL.String()) + } + if len(via) >= 10 { + return errors.New("stopped after 10 redirects") + } + return nil + }, + } + } else { + return &http.Client{ + Transport: &http.Transport{ + Proxy: nil, + DisableKeepAlives: true, + DialContext: dial, + ResponseHeaderTimeout: idleTimeout, + }, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + if req.URL.Scheme != "https" { + return errors.New("redirected to non-https URL: ", req.URL.String()) + } + if len(via) >= 10 { + return errors.New("stopped after 10 redirects") + } + return nil + }, + } } } @@ -160,7 +196,13 @@ func (d *downloader) fetch(rawURL string, writer io.Writer) error { } utils.TryDefaultHeadersWith(req.Header, "nav") - resp, err := d.client.Do(req) + var client *http.Client + if req.URL.Scheme == "https" { + client = d.httpsClient + } else { + client = d.httpClient + } + resp, err := client.Do(req) if err != nil { return err }