more compatibility

This commit is contained in:
Meo597
2026-05-19 03:38:36 +08:00
parent 81bcfcae77
commit d39965f20d
2 changed files with 112 additions and 88 deletions
+7
View File
@@ -274,6 +274,13 @@ func (h *Handler) SocketSettings() *internet.SocketConfig {
return h.streamSettings.SocketSettings return h.streamSettings.SocketSettings
} }
func (h *Handler) UsesProxySettings() bool {
if h.senderSettings != nil && h.senderSettings.ProxySettings.HasTag() {
return true
}
return h.streamSettings != nil && h.streamSettings.SocketSettings != nil && len(h.streamSettings.SocketSettings.DialerProxy) > 0
}
// Dial implements internet.Dialer. // Dial implements internet.Dialer.
func (h *Handler) Dial(ctx context.Context, dest net.Destination) (stat.Connection, error) { func (h *Handler) Dial(ctx context.Context, dest net.Destination) (stat.Connection, error) {
if h.senderSettings != nil { if h.senderSettings != nil {
+105 -88
View File
@@ -43,6 +43,9 @@ func init() {
h.socketStrategy = sockopt.DomainStrategy h.socketStrategy = sockopt.DomainStrategy
} }
} }
if handler, ok := session.FullHandlerFromContext(ctx).(handlerWithProxySettings); ok {
h.usesProxySettings = handler.UsesProxySettings()
}
if err := core.RequireFeatures(ctx, func(pm policy.Manager) error { if err := core.RequireFeatures(ctx, func(pm policy.Manager) error {
return h.Init(config.(*Config), pm) return h.Init(config.(*Config), pm)
}); err != nil { }); err != nil {
@@ -97,6 +100,10 @@ type handlerWithSocketSettings interface {
SocketSettings() *internet.SocketConfig SocketSettings() *internet.SocketConfig
} }
type handlerWithProxySettings interface {
UsesProxySettings() bool
}
type FinalRule struct { type FinalRule struct {
action RuleAction action RuleAction
network [8]bool network [8]bool
@@ -107,10 +114,11 @@ type FinalRule struct {
// Handler handles Freedom connections. // Handler handles Freedom connections.
type Handler struct { type Handler struct {
policyManager policy.Manager policyManager policy.Manager
config *Config config *Config
finalRules []*FinalRule finalRules []*FinalRule
socketStrategy internet.DomainStrategy socketStrategy internet.DomainStrategy
usesProxySettings bool
} }
func buildFinalRule(config *FinalRuleConfig) (*FinalRule, error) { func buildFinalRule(config *FinalRuleConfig) (*FinalRule, error) {
@@ -187,22 +195,6 @@ func getDefaultFinalRule(inbound *session.Inbound) *FinalRule {
return nil return nil
} }
func (h *Handler) shouldResolveDomainBeforeFinalRules(dialDest net.Destination, defaultRule *FinalRule) bool {
if !dialDest.Address.Family().IsDomain() {
return false
}
if len(h.finalRules) > 0 {
rule := h.finalRules[0]
if rule.action == RuleAction_Allow && rule.network[dialDest.Network] && len(rule.port) == 0 && rule.ip == nil {
return false
}
}
if defaultRule != nil || len(h.finalRules) > 0 {
return true
}
return false
}
func (h *Handler) matchFinalRule(network net.Network, address net.Address, port net.Port, defaultRule *FinalRule) *FinalRule { func (h *Handler) matchFinalRule(network net.Network, address net.Address, port net.Port, defaultRule *FinalRule) *FinalRule {
for _, rule := range h.finalRules { for _, rule := range h.finalRules {
if rule.Apply(network, address, port) { if rule.Apply(network, address, port) {
@@ -215,13 +207,6 @@ func (h *Handler) matchFinalRule(network net.Network, address net.Address, port
return nil return nil
} }
func (h *Handler) applyFinalRules(network net.Network, address net.Address, port net.Port, defaultRule *FinalRule) RuleAction {
if rule := h.matchFinalRule(network, address, port, defaultRule); rule != nil {
return rule.action
}
return RuleAction_Allow
}
// Init initializes the Handler with necessary parameters. // Init initializes the Handler with necessary parameters.
func (h *Handler) Init(config *Config, pm policy.Manager) error { func (h *Handler) Init(config *Config, pm policy.Manager) error {
h.config = config h.config = config
@@ -256,6 +241,20 @@ func (h *Handler) blockDelay(rule *FinalRule) time.Duration {
return time.Duration(min+uint64(dice.Roll(int(span+1)))) * time.Second return time.Duration(min+uint64(dice.Roll(int(span+1)))) * time.Second
} }
func (h *Handler) blackhole(ctx context.Context, input buf.Reader, output buf.Writer, rule *FinalRule, dest *net.Destination) error {
delay := h.blockDelay(rule)
errors.LogInfo(ctx, "blocked target: ", *dest, ", blackholing connection for ", delay)
timer := time.AfterFunc(delay, func() {
common.Interrupt(input)
common.Interrupt(output)
errors.LogInfo(ctx, "closed blackholed connection to blocked target: ", *dest)
})
defer timer.Stop()
defer common.Close(output)
_ = buf.Copy(input, buf.Discard)
return nil
}
func (h *Handler) udpDomainStrategy() internet.DomainStrategy { func (h *Handler) udpDomainStrategy() internet.DomainStrategy {
if h.config.DomainStrategy.HasStrategy() { if h.config.DomainStrategy.HasStrategy() {
return h.config.DomainStrategy return h.config.DomainStrategy
@@ -310,58 +309,75 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
var conn stat.Connection var conn stat.Connection
var blockedDest *net.Destination var blockedDest *net.Destination
var blockedRule *FinalRule var blockedRule *FinalRule
firstResolve := true
err := retry.ExponentialBackoff(5, 100).On(func() error { err := retry.ExponentialBackoff(5, 100).On(func() error {
dialDest := destination dialDest := destination
if h.config.DomainStrategy.HasStrategy() && dialDest.Address.Family().IsDomain() {
strategy := h.config.DomainStrategy if dialDest.Address.Family().IsDomain() {
if destination.Network == net.Network_UDP && origTargetAddr != nil && outGateway == nil { if strategy := h.config.DomainStrategy; strategy.HasStrategy() {
strategy = strategy.GetDynamicStrategy(origTargetAddr.Family()) if destination.Network == net.Network_UDP && origTargetAddr != nil && outGateway == nil {
} strategy = strategy.GetDynamicStrategy(origTargetAddr.Family())
ips, err := internet.LookupForIP(dialDest.Address.Domain(), strategy, outGateway)
if err != nil {
errors.LogInfoInner(ctx, err, "failed to get IP address for domain ", dialDest.Address.Domain())
if h.config.DomainStrategy.ForceIP() || h.shouldResolveDomainBeforeFinalRules(dialDest, defaultRule) {
return err
} }
} else { ips, err := internet.LookupForIP(dialDest.Address.Domain(), strategy, outGateway)
dialDest = net.Destination{ if err != nil { // SRV/TXT
Network: dialDest.Network, errors.LogInfoInner(ctx, err, "failed to get IP address for domain ", dialDest.Address.Domain())
Address: net.IPAddress(ips[dice.Roll(len(ips))]), if h.config.DomainStrategy.ForceIP() || defaultRule != nil || len(h.finalRules) > 0 {
Port: dialDest.Port, return err // retry
}
} else { // to ip
dialDest = net.Destination{
Network: dialDest.Network,
Address: net.IPAddress(ips[dice.Roll(len(ips))]),
Port: dialDest.Port,
}
errors.LogInfo(ctx, "dialing to ", dialDest)
if rule := h.matchFinalRule(dialDest.Network, dialDest.Address, dialDest.Port, defaultRule); rule != nil && rule.action == RuleAction_Block {
blockedDest = &dialDest
blockedRule = rule
return nil
}
} }
errors.LogInfo(ctx, "dialing to ", dialDest) } else if defaultRule != nil || len(h.finalRules) > 0 { // freedom asis + hasrules
} if strategy := h.socketStrategy; strategy.HasStrategy() {
} else if h.shouldResolveDomainBeforeFinalRules(dialDest, defaultRule) { // asis + domain + hasrules ips, err := internet.LookupForIP(dialDest.Address.Domain(), strategy, outGateway)
domain := dialDest.Address.Domain() if err != nil { // SRV/TXT
var ips []net.IP errors.LogInfoInner(ctx, err, "failed to get IP address for domain ", dialDest.Address.Domain())
if firstResolve { if strategy.ForceIP() {
firstResolve = false return err // retry
supportIPv4, supportIPv6 := utils.CheckRoutes() }
if supportIPv4 { }
ips, _ = net.DefaultResolver.LookupIP(ctx, "ip4", domain) for _, ip := range ips {
if addr := net.IPAddress(ip); addr != nil {
if rule := h.matchFinalRule(dialDest.Network, addr, dialDest.Port, defaultRule); rule != nil && rule.action == RuleAction_Block {
blockedDest = &dialDest
blockedDest.Address = addr
blockedRule = rule
return nil
}
}
}
} else { // sockopt asis
addrs, err := net.DefaultResolver.LookupIPAddr(ctx, dialDest.Address.Domain())
if err != nil { // SRV/TXT
errors.LogInfoInner(ctx, err, "failed to get IP address for domain ", dialDest.Address.Domain())
}
for _, addr := range addrs {
if ipAddr := net.IPAddress(addr.IP); ipAddr != nil {
if rule := h.matchFinalRule(dialDest.Network, ipAddr, dialDest.Port, defaultRule); rule != nil && rule.action == RuleAction_Block {
blockedDest = &dialDest
blockedDest.Address = ipAddr
blockedRule = rule
return nil
}
}
}
} }
if len(ips) == 0 && supportIPv6 {
ips, _ = net.DefaultResolver.LookupIP(ctx, "ip6", domain)
}
if len(ips) == 0 {
return errors.New("failed to get IP address for domain ", domain)
}
} else {
ips, _ = net.DefaultResolver.LookupIP(ctx, "ip", domain)
} }
if len(ips) == 0 { // SRV/TXT, lookup failed } else {
return errors.New("failed to get IP address for domain ", domain) if rule := h.matchFinalRule(dialDest.Network, dialDest.Address, dialDest.Port, defaultRule); rule != nil && rule.action == RuleAction_Block {
blockedDest = &dialDest
blockedRule = rule
return nil
} }
if addr := net.IPAddress(ips[dice.Roll(len(ips))]); addr != nil {
dialDest.Address = addr
errors.LogInfo(ctx, "dialing to ", dialDest)
}
}
if rule := h.matchFinalRule(dialDest.Network, dialDest.Address, dialDest.Port, defaultRule); rule != nil && rule.action == RuleAction_Block {
blockedDest = &dialDest
blockedRule = rule
return nil
} }
rawConn, err := dialer.Dial(ctx, dialDest) rawConn, err := dialer.Dial(ctx, dialDest)
@@ -376,20 +392,21 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
return errors.New("failed to open connection to ", destination).Base(err) return errors.New("failed to open connection to ", destination).Base(err)
} }
if blockedDest != nil { if blockedDest != nil {
delay := h.blockDelay(blockedRule) return h.blackhole(ctx, input, output, blockedRule, blockedDest)
errors.LogInfo(ctx, "blocked target: ", *blockedDest, ", blackholing connection for ", delay)
timer := time.AfterFunc(delay, func() {
common.Interrupt(input)
common.Interrupt(output)
errors.LogInfo(ctx, "closed blackholed connection to blocked target: ", *blockedDest)
})
defer timer.Stop()
defer common.Close(output)
if err := buf.Copy(input, buf.Discard); err != nil {
return nil
}
return nil
} }
if defaultRule != nil || len(h.finalRules) > 0 {
if h.usesProxySettings {
errors.LogInfo(ctx, "skipping final rule check for proxied remote endpoint, original target: ", destination)
} else {
// SRV/TXT, lookup failed
remoteDest := net.DestinationFromAddr(conn.RemoteAddr())
if rule := h.matchFinalRule(remoteDest.Network, remoteDest.Address, remoteDest.Port, defaultRule); rule != nil && rule.action == RuleAction_Block {
conn.Close()
return h.blackhole(ctx, input, output, rule, &remoteDest)
}
}
}
if h.config.ProxyProtocol > 0 && h.config.ProxyProtocol <= 2 { if h.config.ProxyProtocol > 0 && h.config.ProxyProtocol <= 2 {
version := byte(h.config.ProxyProtocol) version := byte(h.config.ProxyProtocol)
srcAddr := inbound.Source.RawNetAddr() srcAddr := inbound.Source.RawNetAddr()
@@ -538,7 +555,7 @@ func (r *PacketReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
} }
udpAddr := d.(*net.UDPAddr) udpAddr := d.(*net.UDPAddr)
sourceAddr := net.IPAddress(udpAddr.IP) sourceAddr := net.IPAddress(udpAddr.IP)
if r.Handler.applyFinalRules(net.Network_UDP, sourceAddr, net.Port(udpAddr.Port), r.DefaultRule) == RuleAction_Block { if rule := r.Handler.matchFinalRule(net.Network_UDP, sourceAddr, net.Port(udpAddr.Port), r.DefaultRule); rule != nil && rule.action == RuleAction_Block {
continue continue
} }
b.Resize(0, int32(n)) b.Resize(0, int32(n))
@@ -631,7 +648,7 @@ func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
} else { } else {
shouldUseSystemResolver := true shouldUseSystemResolver := true
if resolveStrategy := w.Handler.udpDomainStrategy(); resolveStrategy.HasStrategy() { if resolveStrategy := w.Handler.udpDomainStrategy(); resolveStrategy.HasStrategy() {
ips, err := internet.LookupForIP(b.UDP.Address.Domain(), w.Handler.config.DomainStrategy, w.OutGateway) ips, err := internet.LookupForIP(b.UDP.Address.Domain(), resolveStrategy, w.OutGateway)
if err != nil { if err != nil {
// drop packet if resolve failed when forceIP // drop packet if resolve failed when forceIP
if resolveStrategy.ForceIP() { if resolveStrategy.ForceIP() {
@@ -657,7 +674,7 @@ func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
} }
} }
} }
if w.applyFinalRules(net.Network_UDP, b.UDP.Address, b.UDP.Port, w.DefaultRule) == RuleAction_Block { if rule := w.matchFinalRule(net.Network_UDP, b.UDP.Address, b.UDP.Port, w.DefaultRule); rule != nil && rule.action == RuleAction_Block {
b.Release() b.Release()
continue continue
} }