mirror of
https://github.com/XTLS/Xray-core.git
synced 2026-05-14 10:00:34 +00:00
Proxy: Add Hysteria outbound & transport (version 2, udphop) and Salamander udpmask (#5508)
https://github.com/XTLS/Xray-core/issues/3547#issuecomment-3549896520 https://github.com/XTLS/Xray-core/issues/2635#issuecomment-3570871754
This commit is contained in:
@@ -0,0 +1,263 @@
|
||||
package hysteria
|
||||
|
||||
import (
|
||||
"context"
|
||||
go_errors "errors"
|
||||
"io"
|
||||
"math/rand"
|
||||
|
||||
"github.com/apernet/quic-go"
|
||||
"github.com/xtls/xray-core/common"
|
||||
"github.com/xtls/xray-core/common/buf"
|
||||
"github.com/xtls/xray-core/common/errors"
|
||||
"github.com/xtls/xray-core/common/net"
|
||||
"github.com/xtls/xray-core/common/protocol"
|
||||
"github.com/xtls/xray-core/common/session"
|
||||
"github.com/xtls/xray-core/common/signal"
|
||||
"github.com/xtls/xray-core/common/task"
|
||||
"github.com/xtls/xray-core/core"
|
||||
"github.com/xtls/xray-core/features/policy"
|
||||
"github.com/xtls/xray-core/transport"
|
||||
"github.com/xtls/xray-core/transport/internet"
|
||||
"github.com/xtls/xray-core/transport/internet/hysteria"
|
||||
"github.com/xtls/xray-core/transport/internet/stat"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
server *protocol.ServerSpec
|
||||
policyManager policy.Manager
|
||||
}
|
||||
|
||||
func NewClient(ctx context.Context, config *ClientConfig) (*Client, error) {
|
||||
if config.Server == nil {
|
||||
return nil, errors.New(`no target server found`)
|
||||
}
|
||||
server, err := protocol.NewServerSpecFromPB(config.Server)
|
||||
if err != nil {
|
||||
return nil, errors.New("failed to get server spec").Base(err)
|
||||
}
|
||||
|
||||
v := core.MustFromContext(ctx)
|
||||
client := &Client{
|
||||
server: server,
|
||||
policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (c *Client) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error {
|
||||
outbounds := session.OutboundsFromContext(ctx)
|
||||
ob := outbounds[len(outbounds)-1]
|
||||
if !ob.Target.IsValid() {
|
||||
return errors.New("target not specified")
|
||||
}
|
||||
ob.Name = "hysteria"
|
||||
ob.CanSpliceCopy = 3
|
||||
target := ob.Target
|
||||
|
||||
conn, err := dialer.Dial(ctx, c.server.Destination)
|
||||
if err != nil {
|
||||
return errors.New("failed to find an available destination").AtWarning().Base(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
errors.LogInfo(ctx, "tunneling request to ", target, " via ", target.Network, ":", c.server.Destination.NetAddr())
|
||||
|
||||
var newCtx context.Context
|
||||
var newCancel context.CancelFunc
|
||||
if session.TimeoutOnlyFromContext(ctx) {
|
||||
newCtx, newCancel = context.WithCancel(context.Background())
|
||||
}
|
||||
|
||||
sessionPolicy := c.policyManager.ForLevel(0)
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
timer := signal.CancelAfterInactivity(ctx, func() {
|
||||
cancel()
|
||||
if newCancel != nil {
|
||||
newCancel()
|
||||
}
|
||||
}, sessionPolicy.Timeouts.ConnectionIdle)
|
||||
|
||||
if newCtx != nil {
|
||||
ctx = newCtx
|
||||
}
|
||||
|
||||
if target.Network == net.Network_TCP {
|
||||
requestDone := func() error {
|
||||
defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly)
|
||||
bufferedWriter := buf.NewBufferedWriter(buf.NewWriter(conn))
|
||||
err := WriteTCPRequest(bufferedWriter, target.NetAddr())
|
||||
if err != nil {
|
||||
return errors.New("failed to write request").Base(err)
|
||||
}
|
||||
if err := bufferedWriter.SetBuffered(false); err != nil {
|
||||
return err
|
||||
}
|
||||
return buf.Copy(link.Reader, bufferedWriter, buf.UpdateActivity(timer))
|
||||
}
|
||||
|
||||
responseDone := func() error {
|
||||
defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly)
|
||||
ok, msg, err := ReadTCPResponse(conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !ok {
|
||||
return errors.New(msg)
|
||||
}
|
||||
return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer))
|
||||
}
|
||||
|
||||
responseDoneAndCloseWriter := task.OnSuccess(responseDone, task.Close(link.Writer))
|
||||
if err := task.Run(ctx, requestDone, responseDoneAndCloseWriter); err != nil {
|
||||
return errors.New("connection ends").Base(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
if target.Network == net.Network_UDP {
|
||||
iConn := stat.TryUnwrapStatsConn(conn)
|
||||
_, ok := iConn.(*hysteria.InterUdpConn)
|
||||
if !ok {
|
||||
return errors.New("udp requires hysteria udp transport")
|
||||
}
|
||||
|
||||
requestDone := func() error {
|
||||
defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly)
|
||||
|
||||
writer := &UDPWriter{
|
||||
Writer: conn,
|
||||
buf: make([]byte, MaxUDPSize),
|
||||
addr: target.NetAddr(),
|
||||
}
|
||||
|
||||
if err := buf.Copy(link.Reader, writer, buf.UpdateActivity(timer)); err != nil {
|
||||
return errors.New("failed to transport all UDP request").Base(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
responseDone := func() error {
|
||||
defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly)
|
||||
|
||||
reader := &UDPReader{
|
||||
Reader: conn,
|
||||
df: &Defragger{},
|
||||
}
|
||||
|
||||
if err := buf.Copy(reader, link.Writer, buf.UpdateActivity(timer)); err != nil {
|
||||
return errors.New("failed to transport all UDP response").Base(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
responseDoneAndCloseWriter := task.OnSuccess(responseDone, task.Close(link.Writer))
|
||||
if err := task.Run(ctx, requestDone, responseDoneAndCloseWriter); err != nil {
|
||||
return errors.New("connection ends").Base(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
common.Must(common.RegisterConfig((*ClientConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
|
||||
return NewClient(ctx, config.(*ClientConfig))
|
||||
}))
|
||||
}
|
||||
|
||||
type UDPWriter struct {
|
||||
Writer io.Writer
|
||||
buf []byte
|
||||
addr string
|
||||
}
|
||||
|
||||
func (w *UDPWriter) sendMsg(msg *UDPMessage) error {
|
||||
msgN := msg.Serialize(w.buf)
|
||||
if msgN < 0 {
|
||||
// Message larger than buffer, silent drop
|
||||
return nil
|
||||
}
|
||||
_, err := w.Writer.Write(w.buf[:msgN])
|
||||
return err
|
||||
}
|
||||
|
||||
func (w *UDPWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
|
||||
for {
|
||||
mb2, b := buf.SplitFirst(mb)
|
||||
mb = mb2
|
||||
if b == nil {
|
||||
break
|
||||
}
|
||||
addr := w.addr
|
||||
if b.UDP != nil {
|
||||
addr = b.UDP.NetAddr()
|
||||
}
|
||||
msg := &UDPMessage{
|
||||
SessionID: 0,
|
||||
PacketID: 0,
|
||||
FragID: 0,
|
||||
FragCount: 1,
|
||||
Addr: addr,
|
||||
Data: b.Bytes(),
|
||||
}
|
||||
if err := w.sendMsg(msg); err != nil {
|
||||
var errTooLarge *quic.DatagramTooLargeError
|
||||
if go_errors.As(err, &errTooLarge) {
|
||||
msg.PacketID = uint16(rand.Intn(0xFFFF)) + 1
|
||||
fMsgs := FragUDPMessage(msg, int(errTooLarge.MaxDatagramPayloadSize))
|
||||
for _, fMsg := range fMsgs {
|
||||
err := w.sendMsg(&fMsg)
|
||||
if err != nil {
|
||||
b.Release()
|
||||
buf.ReleaseMulti(mb)
|
||||
return err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
b.Release()
|
||||
buf.ReleaseMulti(mb)
|
||||
return err
|
||||
}
|
||||
}
|
||||
b.Release()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type UDPReader struct {
|
||||
Reader io.Reader
|
||||
df *Defragger
|
||||
}
|
||||
|
||||
func (r *UDPReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
|
||||
for {
|
||||
b := buf.New()
|
||||
_, err := b.ReadFrom(r.Reader)
|
||||
if err != nil {
|
||||
b.Release()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
msg, err := ParseUDPMessage(b.Bytes())
|
||||
if err != nil {
|
||||
b.Release()
|
||||
continue
|
||||
}
|
||||
|
||||
dfMsg := r.df.Feed(msg)
|
||||
if dfMsg == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
dest, _ := net.ParseDestination("udp:" + dfMsg.Addr)
|
||||
|
||||
buffer := buf.New()
|
||||
buffer.Write(dfMsg.Data)
|
||||
buffer.UDP = &dest
|
||||
|
||||
return buf.MultiBuffer{buffer}, nil
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
package hysteria
|
||||
|
||||
import (
|
||||
"github.com/xtls/xray-core/transport/internet/hysteria/padding"
|
||||
)
|
||||
|
||||
var (
|
||||
tcpRequestPadding = padding.Padding{Min: 64, Max: 512}
|
||||
// tcpResponsePadding = padding.Padding{Min: 128, Max: 1024}
|
||||
)
|
||||
@@ -0,0 +1,126 @@
|
||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.36.10
|
||||
// protoc v6.33.1
|
||||
// source: proxy/hysteria/config.proto
|
||||
|
||||
package hysteria
|
||||
|
||||
import (
|
||||
protocol "github.com/xtls/xray-core/common/protocol"
|
||||
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
|
||||
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
|
||||
reflect "reflect"
|
||||
sync "sync"
|
||||
unsafe "unsafe"
|
||||
)
|
||||
|
||||
const (
|
||||
// Verify that this generated code is sufficiently up-to-date.
|
||||
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
|
||||
// Verify that runtime/protoimpl is sufficiently up-to-date.
|
||||
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
|
||||
)
|
||||
|
||||
type ClientConfig struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
Server *protocol.ServerEndpoint `protobuf:"bytes,1,opt,name=server,proto3" json:"server,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *ClientConfig) Reset() {
|
||||
*x = ClientConfig{}
|
||||
mi := &file_proxy_hysteria_config_proto_msgTypes[0]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *ClientConfig) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*ClientConfig) ProtoMessage() {}
|
||||
|
||||
func (x *ClientConfig) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_proxy_hysteria_config_proto_msgTypes[0]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use ClientConfig.ProtoReflect.Descriptor instead.
|
||||
func (*ClientConfig) Descriptor() ([]byte, []int) {
|
||||
return file_proxy_hysteria_config_proto_rawDescGZIP(), []int{0}
|
||||
}
|
||||
|
||||
func (x *ClientConfig) GetServer() *protocol.ServerEndpoint {
|
||||
if x != nil {
|
||||
return x.Server
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var File_proxy_hysteria_config_proto protoreflect.FileDescriptor
|
||||
|
||||
const file_proxy_hysteria_config_proto_rawDesc = "" +
|
||||
"\n" +
|
||||
"\x1bproxy/hysteria/config.proto\x12\x13xray.proxy.hysteria\x1a!common/protocol/server_spec.proto\"L\n" +
|
||||
"\fClientConfig\x12<\n" +
|
||||
"\x06server\x18\x01 \x01(\v2$.xray.common.protocol.ServerEndpointR\x06serverB[\n" +
|
||||
"\x17com.xray.proxy.hysteriaP\x01Z(github.com/xtls/xray-core/proxy/hysteria\xaa\x02\x13Xray.Proxy.Hysteriab\x06proto3"
|
||||
|
||||
var (
|
||||
file_proxy_hysteria_config_proto_rawDescOnce sync.Once
|
||||
file_proxy_hysteria_config_proto_rawDescData []byte
|
||||
)
|
||||
|
||||
func file_proxy_hysteria_config_proto_rawDescGZIP() []byte {
|
||||
file_proxy_hysteria_config_proto_rawDescOnce.Do(func() {
|
||||
file_proxy_hysteria_config_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_proxy_hysteria_config_proto_rawDesc), len(file_proxy_hysteria_config_proto_rawDesc)))
|
||||
})
|
||||
return file_proxy_hysteria_config_proto_rawDescData
|
||||
}
|
||||
|
||||
var file_proxy_hysteria_config_proto_msgTypes = make([]protoimpl.MessageInfo, 1)
|
||||
var file_proxy_hysteria_config_proto_goTypes = []any{
|
||||
(*ClientConfig)(nil), // 0: xray.proxy.hysteria.ClientConfig
|
||||
(*protocol.ServerEndpoint)(nil), // 1: xray.common.protocol.ServerEndpoint
|
||||
}
|
||||
var file_proxy_hysteria_config_proto_depIdxs = []int32{
|
||||
1, // 0: xray.proxy.hysteria.ClientConfig.server:type_name -> xray.common.protocol.ServerEndpoint
|
||||
1, // [1:1] is the sub-list for method output_type
|
||||
1, // [1:1] is the sub-list for method input_type
|
||||
1, // [1:1] is the sub-list for extension type_name
|
||||
1, // [1:1] is the sub-list for extension extendee
|
||||
0, // [0:1] is the sub-list for field type_name
|
||||
}
|
||||
|
||||
func init() { file_proxy_hysteria_config_proto_init() }
|
||||
func file_proxy_hysteria_config_proto_init() {
|
||||
if File_proxy_hysteria_config_proto != nil {
|
||||
return
|
||||
}
|
||||
type x struct{}
|
||||
out := protoimpl.TypeBuilder{
|
||||
File: protoimpl.DescBuilder{
|
||||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||
RawDescriptor: unsafe.Slice(unsafe.StringData(file_proxy_hysteria_config_proto_rawDesc), len(file_proxy_hysteria_config_proto_rawDesc)),
|
||||
NumEnums: 0,
|
||||
NumMessages: 1,
|
||||
NumExtensions: 0,
|
||||
NumServices: 0,
|
||||
},
|
||||
GoTypes: file_proxy_hysteria_config_proto_goTypes,
|
||||
DependencyIndexes: file_proxy_hysteria_config_proto_depIdxs,
|
||||
MessageInfos: file_proxy_hysteria_config_proto_msgTypes,
|
||||
}.Build()
|
||||
File_proxy_hysteria_config_proto = out.File
|
||||
file_proxy_hysteria_config_proto_goTypes = nil
|
||||
file_proxy_hysteria_config_proto_depIdxs = nil
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package xray.proxy.hysteria;
|
||||
option csharp_namespace = "Xray.Proxy.Hysteria";
|
||||
option go_package = "github.com/xtls/xray-core/proxy/hysteria";
|
||||
option java_package = "com.xray.proxy.hysteria";
|
||||
option java_multiple_files = true;
|
||||
|
||||
import "common/protocol/server_spec.proto";
|
||||
|
||||
message ClientConfig {
|
||||
xray.common.protocol.ServerEndpoint server = 1;
|
||||
}
|
||||
@@ -0,0 +1,73 @@
|
||||
package hysteria
|
||||
|
||||
func FragUDPMessage(m *UDPMessage, maxSize int) []UDPMessage {
|
||||
if m.Size() <= maxSize {
|
||||
return []UDPMessage{*m}
|
||||
}
|
||||
fullPayload := m.Data
|
||||
maxPayloadSize := maxSize - m.HeaderSize()
|
||||
off := 0
|
||||
fragID := uint8(0)
|
||||
fragCount := uint8((len(fullPayload) + maxPayloadSize - 1) / maxPayloadSize) // round up
|
||||
frags := make([]UDPMessage, fragCount)
|
||||
for off < len(fullPayload) {
|
||||
payloadSize := len(fullPayload) - off
|
||||
if payloadSize > maxPayloadSize {
|
||||
payloadSize = maxPayloadSize
|
||||
}
|
||||
frag := *m
|
||||
frag.FragID = fragID
|
||||
frag.FragCount = fragCount
|
||||
frag.Data = fullPayload[off : off+payloadSize]
|
||||
frags[fragID] = frag
|
||||
off += payloadSize
|
||||
fragID++
|
||||
}
|
||||
return frags
|
||||
}
|
||||
|
||||
// Defragger handles the defragmentation of UDP messages.
|
||||
// The current implementation can only handle one packet ID at a time.
|
||||
// If another packet arrives before a packet has received all fragments
|
||||
// in their entirety, any previous state is discarded.
|
||||
type Defragger struct {
|
||||
pktID uint16
|
||||
frags []*UDPMessage
|
||||
count uint8
|
||||
size int // data size
|
||||
}
|
||||
|
||||
func (d *Defragger) Feed(m *UDPMessage) *UDPMessage {
|
||||
if m.FragCount <= 1 {
|
||||
return m
|
||||
}
|
||||
if m.FragID >= m.FragCount {
|
||||
// wtf is this?
|
||||
return nil
|
||||
}
|
||||
if m.PacketID != d.pktID || m.FragCount != uint8(len(d.frags)) {
|
||||
// new message, clear previous state
|
||||
d.pktID = m.PacketID
|
||||
d.frags = make([]*UDPMessage, m.FragCount)
|
||||
d.frags[m.FragID] = m
|
||||
d.count = 1
|
||||
d.size = len(m.Data)
|
||||
} else if d.frags[m.FragID] == nil {
|
||||
d.frags[m.FragID] = m
|
||||
d.count++
|
||||
d.size += len(m.Data)
|
||||
if int(d.count) == len(d.frags) {
|
||||
// all fragments received, assemble
|
||||
data := make([]byte, d.size)
|
||||
off := 0
|
||||
for _, frag := range d.frags {
|
||||
off += copy(data[off:], frag.Data)
|
||||
}
|
||||
m.Data = data
|
||||
m.FragID = 0
|
||||
m.FragCount = 1
|
||||
return m
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,204 @@
|
||||
package hysteria
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/apernet/quic-go/quicvarint"
|
||||
"github.com/xtls/xray-core/common/errors"
|
||||
)
|
||||
|
||||
const (
|
||||
FrameTypeTCPRequest = 0x401
|
||||
|
||||
// Max length values are for preventing DoS attacks
|
||||
|
||||
MaxAddressLength = 2048
|
||||
MaxMessageLength = 2048
|
||||
MaxPaddingLength = 4096
|
||||
|
||||
MaxUDPSize = 4096
|
||||
|
||||
maxVarInt1 = 63
|
||||
maxVarInt2 = 16383
|
||||
maxVarInt4 = 1073741823
|
||||
maxVarInt8 = 4611686018427387903
|
||||
)
|
||||
|
||||
// TCPRequest format:
|
||||
// 0x401 (QUIC varint)
|
||||
// Address length (QUIC varint)
|
||||
// Address (bytes)
|
||||
// Padding length (QUIC varint)
|
||||
// Padding (bytes)
|
||||
|
||||
func WriteTCPRequest(w io.Writer, addr string) error {
|
||||
padding := tcpRequestPadding.String()
|
||||
paddingLen := len(padding)
|
||||
addrLen := len(addr)
|
||||
sz := int(quicvarint.Len(FrameTypeTCPRequest)) +
|
||||
int(quicvarint.Len(uint64(addrLen))) + addrLen +
|
||||
int(quicvarint.Len(uint64(paddingLen))) + paddingLen
|
||||
buf := make([]byte, sz)
|
||||
i := varintPut(buf, FrameTypeTCPRequest)
|
||||
i += varintPut(buf[i:], uint64(addrLen))
|
||||
i += copy(buf[i:], addr)
|
||||
i += varintPut(buf[i:], uint64(paddingLen))
|
||||
copy(buf[i:], padding)
|
||||
_, err := w.Write(buf)
|
||||
return err
|
||||
}
|
||||
|
||||
// TCPResponse format:
|
||||
// Status (byte, 0=ok, 1=error)
|
||||
// Message length (QUIC varint)
|
||||
// Message (bytes)
|
||||
// Padding length (QUIC varint)
|
||||
// Padding (bytes)
|
||||
|
||||
func ReadTCPResponse(r io.Reader) (bool, string, error) {
|
||||
var status [1]byte
|
||||
if _, err := io.ReadFull(r, status[:]); err != nil {
|
||||
return false, "", err
|
||||
}
|
||||
bReader := quicvarint.NewReader(r)
|
||||
msgLen, err := quicvarint.Read(bReader)
|
||||
if err != nil {
|
||||
return false, "", err
|
||||
}
|
||||
if msgLen > MaxMessageLength {
|
||||
return false, "", errors.New("invalid message length")
|
||||
}
|
||||
var msgBuf []byte
|
||||
// No message is fine
|
||||
if msgLen > 0 {
|
||||
msgBuf = make([]byte, msgLen)
|
||||
_, err = io.ReadFull(r, msgBuf)
|
||||
if err != nil {
|
||||
return false, "", err
|
||||
}
|
||||
}
|
||||
paddingLen, err := quicvarint.Read(bReader)
|
||||
if err != nil {
|
||||
return false, "", err
|
||||
}
|
||||
if paddingLen > MaxPaddingLength {
|
||||
return false, "", errors.New("invalid padding length")
|
||||
}
|
||||
if paddingLen > 0 {
|
||||
_, err = io.CopyN(io.Discard, r, int64(paddingLen))
|
||||
if err != nil {
|
||||
return false, "", err
|
||||
}
|
||||
}
|
||||
return status[0] == 0, string(msgBuf), nil
|
||||
}
|
||||
|
||||
// UDPMessage format:
|
||||
// Session ID (uint32 BE)
|
||||
// Packet ID (uint16 BE)
|
||||
// Fragment ID (uint8)
|
||||
// Fragment count (uint8)
|
||||
// Address length (QUIC varint)
|
||||
// Address (bytes)
|
||||
// Data...
|
||||
|
||||
type UDPMessage struct {
|
||||
SessionID uint32 // 4
|
||||
PacketID uint16 // 2
|
||||
FragID uint8 // 1
|
||||
FragCount uint8 // 1
|
||||
Addr string // varint + bytes
|
||||
Data []byte
|
||||
}
|
||||
|
||||
func (m *UDPMessage) HeaderSize() int {
|
||||
lAddr := len(m.Addr)
|
||||
return 4 + 2 + 1 + 1 + int(quicvarint.Len(uint64(lAddr))) + lAddr
|
||||
}
|
||||
|
||||
func (m *UDPMessage) Size() int {
|
||||
return m.HeaderSize() + len(m.Data)
|
||||
}
|
||||
|
||||
func (m *UDPMessage) Serialize(buf []byte) int {
|
||||
// Make sure the buffer is big enough
|
||||
if len(buf) < m.Size() {
|
||||
return -1
|
||||
}
|
||||
// binary.BigEndian.PutUint32(buf, m.SessionID)
|
||||
binary.BigEndian.PutUint16(buf[4:], m.PacketID)
|
||||
buf[6] = m.FragID
|
||||
buf[7] = m.FragCount
|
||||
i := varintPut(buf[8:], uint64(len(m.Addr)))
|
||||
i += copy(buf[8+i:], m.Addr)
|
||||
i += copy(buf[8+i:], m.Data)
|
||||
return 8 + i
|
||||
}
|
||||
|
||||
func ParseUDPMessage(msg []byte) (*UDPMessage, error) {
|
||||
m := &UDPMessage{}
|
||||
buf := bytes.NewBuffer(msg)
|
||||
if err := binary.Read(buf, binary.BigEndian, &m.SessionID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := binary.Read(buf, binary.BigEndian, &m.PacketID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := binary.Read(buf, binary.BigEndian, &m.FragID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := binary.Read(buf, binary.BigEndian, &m.FragCount); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
lAddr, err := quicvarint.Read(buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if lAddr == 0 || lAddr > MaxMessageLength {
|
||||
return nil, errors.New("invalid address length")
|
||||
}
|
||||
bs := buf.Bytes()
|
||||
if len(bs) <= int(lAddr) {
|
||||
// We use <= instead of < here as we expect at least one byte of data after the address
|
||||
return nil, errors.New("invalid message length")
|
||||
}
|
||||
m.Addr = string(bs[:lAddr])
|
||||
m.Data = bs[lAddr:]
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// varintPut is like quicvarint.Append, but instead of appending to a slice,
|
||||
// it writes to a fixed-size buffer. Returns the number of bytes written.
|
||||
func varintPut(b []byte, i uint64) int {
|
||||
if i <= maxVarInt1 {
|
||||
b[0] = uint8(i)
|
||||
return 1
|
||||
}
|
||||
if i <= maxVarInt2 {
|
||||
b[0] = uint8(i>>8) | 0x40
|
||||
b[1] = uint8(i)
|
||||
return 2
|
||||
}
|
||||
if i <= maxVarInt4 {
|
||||
b[0] = uint8(i>>24) | 0x80
|
||||
b[1] = uint8(i >> 16)
|
||||
b[2] = uint8(i >> 8)
|
||||
b[3] = uint8(i)
|
||||
return 4
|
||||
}
|
||||
if i <= maxVarInt8 {
|
||||
b[0] = uint8(i>>56) | 0xc0
|
||||
b[1] = uint8(i >> 48)
|
||||
b[2] = uint8(i >> 40)
|
||||
b[3] = uint8(i >> 32)
|
||||
b[4] = uint8(i >> 24)
|
||||
b[5] = uint8(i >> 16)
|
||||
b[6] = uint8(i >> 8)
|
||||
b[7] = uint8(i)
|
||||
return 8
|
||||
}
|
||||
panic(fmt.Sprintf("%#x doesn't fit into 62 bits", i))
|
||||
}
|
||||
Reference in New Issue
Block a user