package splithttp import ( "bytes" "context" gotls "crypto/tls" "encoding/base64" "fmt" "io" "net/http" "strconv" "strings" "sync" "time" "github.com/apernet/quic-go" "github.com/apernet/quic-go/http3" goreality "github.com/xtls/reality" "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/net" http_proto "github.com/xtls/xray-core/common/protocol/http" "github.com/xtls/xray-core/common/signal/done" "github.com/xtls/xray-core/transport/internet" "github.com/xtls/xray-core/transport/internet/reality" "github.com/xtls/xray-core/transport/internet/stat" "github.com/xtls/xray-core/transport/internet/tls" ) type requestHandler struct { config *Config host string path string ln *Listener sessionMu *sync.Mutex sessions sync.Map localAddr net.Addr socketSettings *internet.SocketConfig } type httpSession struct { uploadQueue *uploadQueue // for as long as the GET request is not opened by the client, this will be // open ("undone"), and the session may be expired within a certain TTL. // after the client connects, this becomes "done" and the session lives as // long as the GET request. isFullyConnected *done.Instance } func (h *requestHandler) upsertSession(sessionId string) *httpSession { // fast path currentSessionAny, ok := h.sessions.Load(sessionId) if ok { return currentSessionAny.(*httpSession) } // slow path h.sessionMu.Lock() defer h.sessionMu.Unlock() currentSessionAny, ok = h.sessions.Load(sessionId) if ok { return currentSessionAny.(*httpSession) } s := &httpSession{ uploadQueue: NewUploadQueue(h.ln.config.GetNormalizedScMaxBufferedPosts()), isFullyConnected: done.New(), } h.sessions.Store(sessionId, s) shouldReap := done.New() go func() { time.Sleep(30 * time.Second) shouldReap.Close() }() go func() { select { case <-shouldReap.Wait(): h.sessions.Delete(sessionId) s.uploadQueue.Close() case <-s.isFullyConnected.Wait(): } }() return s } func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { if len(h.host) > 0 && !internet.IsValidHTTPHost(request.Host, h.host) { errors.LogInfo(context.Background(), "failed to validate host, request:", request.Host, ", config:", h.host) writer.WriteHeader(http.StatusNotFound) return } if !strings.HasPrefix(request.URL.Path, h.path) { errors.LogInfo(context.Background(), "failed to validate path, request:", request.URL.Path, ", config:", h.path) writer.WriteHeader(http.StatusNotFound) return } h.config.WriteResponseHeader(writer) length := int(h.config.GetNormalizedXPaddingBytes().rand()) config := XPaddingConfig{Length: length} if h.config.XPaddingObfsMode { config.Placement = XPaddingPlacement{ Placement: h.config.XPaddingPlacement, Key: h.config.XPaddingKey, Header: h.config.XPaddingHeader, } config.Method = PaddingMethod(h.config.XPaddingMethod) } else { config.Placement = XPaddingPlacement{ Placement: PlacementHeader, Header: "X-Padding", } } h.config.ApplyXPaddingToHeader(writer.Header(), config) /* clientVer := []int{0, 0, 0} x_version := strings.Split(request.URL.Query().Get("x_version"), ".") for j := 0; j < 3 && len(x_version) > j; j++ { clientVer[j], _ = strconv.Atoi(x_version[j]) } */ validRange := h.config.GetNormalizedXPaddingBytes() paddingValue, paddingPlacement := h.config.ExtractXPaddingFromRequest(request, h.config.XPaddingObfsMode) if !h.config.IsPaddingValid(paddingValue, validRange.From, validRange.To, PaddingMethod(h.config.XPaddingMethod)) { errors.LogInfo(context.Background(), "invalid padding ("+paddingPlacement+") length:", int32(len(paddingValue))) writer.WriteHeader(http.StatusBadRequest) return } sessionId, seqStr := h.config.ExtractMetaFromRequest(request, h.path) if sessionId == "" && h.config.Mode != "" && h.config.Mode != "auto" && h.config.Mode != "stream-one" && h.config.Mode != "stream-up" { errors.LogInfo(context.Background(), "stream-one mode is not allowed") writer.WriteHeader(http.StatusBadRequest) return } var forwardedAddrs []net.Address if h.socketSettings != nil && len(h.socketSettings.TrustedXForwardedFor) > 0 { for _, key := range h.socketSettings.TrustedXForwardedFor { if len(request.Header.Values(key)) > 0 { forwardedAddrs = http_proto.ParseXForwardedFor(request.Header) break } } } else { forwardedAddrs = http_proto.ParseXForwardedFor(request.Header) } var remoteAddr net.Addr var err error remoteAddr, err = net.ResolveTCPAddr("tcp", request.RemoteAddr) if err != nil { remoteAddr = &net.TCPAddr{ IP: []byte{0, 0, 0, 0}, Port: 0, } } if request.ProtoMajor == 3 { remoteAddr = &net.UDPAddr{ IP: remoteAddr.(*net.TCPAddr).IP, Port: remoteAddr.(*net.TCPAddr).Port, } } if len(forwardedAddrs) > 0 && forwardedAddrs[0].Family().IsIP() { remoteAddr = &net.TCPAddr{ IP: forwardedAddrs[0].IP(), Port: 0, } } var currentSession *httpSession if sessionId != "" { currentSession = h.upsertSession(sessionId) } scMaxEachPostBytes := int(h.ln.config.GetNormalizedScMaxEachPostBytes().To) uplinkHTTPMethod := h.config.GetNormalizedUplinkHTTPMethod() isUplinkRequest := false if uplinkHTTPMethod != "GET" && request.Method == uplinkHTTPMethod { isUplinkRequest = true } uplinkDataPlacement := h.config.GetNormalizedUplinkDataPlacement() uplinkDataKey := h.config.UplinkDataKey switch uplinkDataPlacement { case PlacementHeader: if request.Header.Get(uplinkDataKey+"-Upstream") == "1" { isUplinkRequest = true } case PlacementCookie: if c, _ := request.Cookie(uplinkDataKey + "_upstream"); c != nil && c.Value == "1" { isUplinkRequest = true } } if isUplinkRequest && sessionId != "" { // stream-up, packet-up if seqStr == "" { if h.config.Mode != "" && h.config.Mode != "auto" && h.config.Mode != "stream-up" { errors.LogInfo(context.Background(), "stream-up mode is not allowed") writer.WriteHeader(http.StatusBadRequest) return } httpSC := &httpServerConn{ Instance: done.New(), Reader: request.Body, ResponseWriter: writer, } err = currentSession.uploadQueue.Push(Packet{ Reader: httpSC, }) if err != nil { errors.LogInfoInner(context.Background(), err, "failed to upload (PushReader)") writer.WriteHeader(http.StatusConflict) } else { writer.Header().Set("X-Accel-Buffering", "no") writer.Header().Set("Cache-Control", "no-store") writer.WriteHeader(http.StatusOK) scStreamUpServerSecs := h.config.GetNormalizedScStreamUpServerSecs() referrer := request.Header.Get("Referer") if referrer != "" && scStreamUpServerSecs.To > 0 { go func() { for { _, err := httpSC.Write(bytes.Repeat([]byte{'X'}, int(h.config.GetNormalizedXPaddingBytes().rand()))) if err != nil { break } time.Sleep(time.Duration(scStreamUpServerSecs.rand()) * time.Second) } }() } select { case <-request.Context().Done(): case <-httpSC.Wait(): } } httpSC.Close() return } if h.config.Mode != "" && h.config.Mode != "auto" && h.config.Mode != "packet-up" { errors.LogInfo(context.Background(), "packet-up mode is not allowed") writer.WriteHeader(http.StatusBadRequest) return } var payload []byte if uplinkDataPlacement != PlacementBody { var encodedStr string switch uplinkDataPlacement { case PlacementHeader: dataLenStr := request.Header.Get(uplinkDataKey + "-Length") if dataLenStr != "" { dataLen, _ := strconv.Atoi(dataLenStr) var chunks []string i := 0 for { chunk := request.Header.Get(fmt.Sprintf("%s-%d", uplinkDataKey, i)) if chunk == "" { break } chunks = append(chunks, chunk) i++ } encodedStr = strings.Join(chunks, "") if len(encodedStr) != dataLen { encodedStr = "" } } case PlacementCookie: var chunks []string i := 0 for { cookieName := fmt.Sprintf("%s_%d", uplinkDataKey, i) if c, _ := request.Cookie(cookieName); c != nil { chunks = append(chunks, c.Value) i++ } else { break } } if len(chunks) > 0 { encodedStr = strings.Join(chunks, "") } } if encodedStr != "" { payload, err = base64.RawURLEncoding.DecodeString(encodedStr) } else { errors.LogInfoInner(context.Background(), err, "failed to extract data from key "+uplinkDataKey+" placed in "+uplinkDataPlacement) writer.WriteHeader(http.StatusInternalServerError) return } } else { payload, err = io.ReadAll(io.LimitReader(request.Body, int64(scMaxEachPostBytes)+1)) } if len(payload) > scMaxEachPostBytes { errors.LogInfo(context.Background(), "Too large upload. scMaxEachPostBytes is set to ", scMaxEachPostBytes, "but request size exceed it. Adjust scMaxEachPostBytes on the server to be at least as large as client.") writer.WriteHeader(http.StatusRequestEntityTooLarge) return } if err != nil { errors.LogInfoInner(context.Background(), err, "failed to upload (ReadAll)") writer.WriteHeader(http.StatusInternalServerError) return } seq, err := strconv.ParseUint(seqStr, 10, 64) if err != nil { errors.LogInfoInner(context.Background(), err, "failed to upload (ParseUint)") writer.WriteHeader(http.StatusInternalServerError) return } err = currentSession.uploadQueue.Push(Packet{ Payload: payload, Seq: seq, }) if err != nil { errors.LogInfoInner(context.Background(), err, "failed to upload (PushPayload)") writer.WriteHeader(http.StatusInternalServerError) return } writer.WriteHeader(http.StatusOK) } else if request.Method == "GET" || sessionId == "" { // stream-down, stream-one if sessionId != "" { // after GET is done, the connection is finished. disable automatic // session reaping, and handle it in defer currentSession.isFullyConnected.Close() defer h.sessions.Delete(sessionId) } // magic header instructs nginx + apache to not buffer response body writer.Header().Set("X-Accel-Buffering", "no") // A web-compliant header telling all middleboxes to disable caching. // Should be able to prevent overloading the cache, or stop CDNs from // teeing the response stream into their cache, causing slowdowns. writer.Header().Set("Cache-Control", "no-store") if !h.config.NoSSEHeader { // magic header to make the HTTP middle box consider this as SSE to disable buffer writer.Header().Set("Content-Type", "text/event-stream") } writer.WriteHeader(http.StatusOK) writer.(http.Flusher).Flush() httpSC := &httpServerConn{ Instance: done.New(), Reader: request.Body, ResponseWriter: writer, } conn := splitConn{ writer: httpSC, reader: httpSC, remoteAddr: remoteAddr, localAddr: h.localAddr, } if sessionId != "" { // if not stream-one conn.reader = currentSession.uploadQueue } h.ln.addConn(stat.Connection(&conn)) // "A ResponseWriter may not be used after [Handler.ServeHTTP] has returned." select { case <-request.Context().Done(): case <-httpSC.Wait(): } conn.Close() } else { errors.LogInfo(context.Background(), "unsupported method: ", request.Method) writer.WriteHeader(http.StatusMethodNotAllowed) } } type httpServerConn struct { sync.Mutex *done.Instance io.Reader // no need to Close request.Body http.ResponseWriter } func (c *httpServerConn) Write(b []byte) (int, error) { c.Lock() defer c.Unlock() if c.Done() { return 0, io.ErrClosedPipe } n, err := c.ResponseWriter.Write(b) if err == nil { c.ResponseWriter.(http.Flusher).Flush() } return n, err } func (c *httpServerConn) Close() error { c.Lock() defer c.Unlock() return c.Instance.Close() } type Listener struct { sync.Mutex server http.Server h3server *http3.Server listener net.Listener h3listener *quic.EarlyListener config *Config addConn internet.ConnHandler isH3 bool } func ListenXH(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, addConn internet.ConnHandler) (internet.Listener, error) { l := &Listener{ addConn: addConn, } l.config = streamSettings.ProtocolSettings.(*Config) if l.config != nil { if streamSettings.SocketSettings == nil { streamSettings.SocketSettings = &internet.SocketConfig{} } } handler := &requestHandler{ config: l.config, host: l.config.Host, path: l.config.GetNormalizedPath(), ln: l, sessionMu: &sync.Mutex{}, sessions: sync.Map{}, socketSettings: streamSettings.SocketSettings, } tlsConfig := getTLSConfig(streamSettings) l.isH3 = len(tlsConfig.NextProtos) == 1 && tlsConfig.NextProtos[0] == "h3" var err error if port == net.Port(0) { // unix l.listener, err = internet.ListenSystem(ctx, &net.UnixAddr{ Name: address.Domain(), Net: "unix", }, streamSettings.SocketSettings) if err != nil { return nil, errors.New("failed to listen UNIX domain socket for XHTTP on ", address).Base(err) } errors.LogInfo(ctx, "listening UNIX domain socket for XHTTP on ", address) } else if l.isH3 { // quic Conn, err := internet.ListenSystemPacket(context.Background(), &net.UDPAddr{ IP: address.IP(), Port: int(port), }, streamSettings.SocketSettings) if err != nil { return nil, errors.New("failed to listen UDP for XHTTP/3 on ", address, ":", port).Base(err) } l.h3listener, err = quic.ListenEarly(Conn, tlsConfig, nil) if err != nil { return nil, errors.New("failed to listen QUIC for XHTTP/3 on ", address, ":", port).Base(err) } errors.LogInfo(ctx, "listening QUIC for XHTTP/3 on ", address, ":", port) handler.localAddr = l.h3listener.Addr() l.h3server = &http3.Server{ Handler: handler, } go func() { if err := l.h3server.ServeListener(l.h3listener); err != nil { errors.LogErrorInner(ctx, err, "failed to serve HTTP/3 for XHTTP/3") } }() } else { // tcp l.listener, err = internet.ListenSystem(ctx, &net.TCPAddr{ IP: address.IP(), Port: int(port), }, streamSettings.SocketSettings) if err != nil { return nil, errors.New("failed to listen TCP for XHTTP on ", address, ":", port).Base(err) } errors.LogInfo(ctx, "listening TCP for XHTTP on ", address, ":", port) } // tcp/unix (h1/h2) if l.listener != nil { if config := tls.ConfigFromStreamSettings(streamSettings); config != nil { if tlsConfig := config.GetTLSConfig(); tlsConfig != nil { l.listener = gotls.NewListener(l.listener, tlsConfig) } } if config := reality.ConfigFromStreamSettings(streamSettings); config != nil { l.listener = goreality.NewListener(l.listener, config.GetREALITYConfig()) } handler.localAddr = l.listener.Addr() // server can handle both plaintext HTTP/1.1 and h2c protocols := new(http.Protocols) protocols.SetHTTP1(true) protocols.SetUnencryptedHTTP2(true) l.server = http.Server{ Handler: handler, ReadHeaderTimeout: time.Second * 4, MaxHeaderBytes: 8192, Protocols: protocols, } go func() { if err := l.server.Serve(l.listener); err != nil { errors.LogErrorInner(ctx, err, "failed to serve HTTP for XHTTP") } }() } return l, err } // Addr implements net.Listener.Addr(). func (ln *Listener) Addr() net.Addr { if ln.h3listener != nil { return ln.h3listener.Addr() } if ln.listener != nil { return ln.listener.Addr() } return nil } // Close implements net.Listener.Close(). func (ln *Listener) Close() error { if ln.h3server != nil { if err := ln.h3server.Close(); err != nil { return err } } else if ln.listener != nil { return ln.listener.Close() } return errors.New("listener does not have an HTTP/3 server or a net.listener") } func getTLSConfig(streamSettings *internet.MemoryStreamConfig) *gotls.Config { config := tls.ConfigFromStreamSettings(streamSettings) if config == nil { return &gotls.Config{} } return config.GetTLSConfig() } func init() { common.Must(internet.RegisterTransportListener(protocolName, ListenXH)) }