diff --git a/README.md b/README.md index 7a61892f..5b54a6e0 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ [![Based on TON][ton-svg]][ton] [![Telegram Channel][tgc-svg]][tg-channel] -![Coverage](https://img.shields.io/badge/Coverage-69.7%25-yellow) +![Coverage](https://img.shields.io/badge/Coverage-70.5%25-brightgreen) Golang library for interacting with TON blockchain. diff --git a/adnl/adnl.go b/adnl/adnl.go index 6825f9bd..80287964 100644 --- a/adnl/adnl.go +++ b/adnl/adnl.go @@ -71,6 +71,8 @@ type ADNL struct { recvPriorityAddrVer int32 ourAddrVerOnPeerSide int32 + peerID []byte + sharedKey []byte peerKey ed25519.PublicKey ourAddresses unsafe.Pointer @@ -175,6 +177,16 @@ func (a *ADNL) processPacket(packet *PacketContent, fromChannel bool) (err error if !fromChannel && packet.From != nil { a.mx.Lock() if a.peerKey == nil { + a.sharedKey, err = keys.SharedKey(a.ourKey, packet.From.Key) + if err != nil { + return err + } + + a.peerID, err = tl.Hash(keys.PublicKeyED25519{Key: packet.From.Key}) + if err != nil { + return err + } + a.peerKey = packet.From.Key } a.mx.Unlock() @@ -735,12 +747,11 @@ func (a *ADNL) GetAddressList() address.List { } func (a *ADNL) GetID() []byte { - id, _ := tl.Hash(keys.PublicKeyED25519{Key: a.peerKey}) - return id + return append([]byte{}, a.peerID...) } func (a *ADNL) GetPubKey() ed25519.PublicKey { - return a.peerKey + return append(ed25519.PublicKey{}, a.peerKey...) 
} func (a *ADNL) Reinit() { @@ -825,23 +836,14 @@ func (a *ADNL) createPacket(seqno int64, isResp bool, msgs ...any) ([]byte, erro hash := sha256.Sum256(packetData) checksum := hash[:] - key, err := keys.SharedKey(a.ourKey, a.peerKey) - if err != nil { - return nil, err - } - - ctr, err := keys.BuildSharedCipher(key, checksum) + ctr, err := keys.BuildSharedCipher(a.sharedKey, checksum) if err != nil { return nil, err } ctr.XORKeyStream(packetData, packetData) - enc, err := tl.Hash(keys.PublicKeyED25519{Key: a.peerKey}) - if err != nil { - return nil, err - } - copy(bufData, enc) + copy(bufData, a.peerID) copy(bufData[32:], a.ourKey.Public().(ed25519.PublicKey)) copy(bufData[64:], checksum) diff --git a/adnl/dht/client.go b/adnl/dht/client.go index 449c5a64..88866e74 100644 --- a/adnl/dht/client.go +++ b/adnl/dht/client.go @@ -450,6 +450,7 @@ func (c *Client) FindValue(ctx context.Context, key *Key, continuation ...*Conti cond := sync.NewCond(&sync.Mutex{}) waitingThreads := 0 + stopped := false launchWorker := func() { for { @@ -465,12 +466,18 @@ func (c *Client) FindValue(ctx context.Context, key *Key, continuation ...*Conti for node == nil { waitingThreads++ if waitingThreads == threads { + stopped = true + cond.Broadcast() cond.L.Unlock() result <- nil return } cond.Wait() + if stopped { + cond.L.Unlock() + return + } node, _ = plist.Get() waitingThreads-- } @@ -485,20 +492,27 @@ func (c *Client) FindValue(ctx context.Context, key *Key, continuation ...*Conti switch v := val.(type) { case *Value: + cond.L.Lock() + if !stopped { + stopped = true + cond.Broadcast() + } + cond.L.Unlock() result <- &foundResult{value: v, node: node} return case []*Node: added := false + cond.L.Lock() for _, n := range v { if newNode, err := c.addNode(n); err == nil { plist.Add(newNode) added = true } } - if added { cond.Broadcast() } + cond.L.Unlock() } } } @@ -509,8 +523,20 @@ func (c *Client) FindValue(ctx context.Context, key *Key, continuation ...*Conti select { case 
<-ctx.Done(): + cond.L.Lock() + if !stopped { + stopped = true + cond.Broadcast() + } + cond.L.Unlock() return nil, nil, ctx.Err() case val := <-result: + cond.L.Lock() + if !stopped { + stopped = true + cond.Broadcast() + } + cond.L.Unlock() if val == nil { return nil, cont, ErrDHTValueIsNotFound } diff --git a/adnl/gateway.go b/adnl/gateway.go index c07a0320..fe72435d 100644 --- a/adnl/gateway.go +++ b/adnl/gateway.go @@ -431,8 +431,23 @@ func (g *Gateway) registerClient(addr net.Addr, key ed25519.PublicKey, id string addrList.Version = addrList.ReinitDate a := g.initADNL() - a.SetAddresses(addrList) + + sharedKey, err := keys.SharedKey(a.ourKey, key) + if err != nil { + return nil, err + } + + peerId, err := tl.Hash(keys.PublicKeyED25519{Key: key}) + if err != nil { + return nil, err + } + + a.peerID = peerId + a.sharedKey = sharedKey a.peerKey = key + + a.SetAddresses(addrList) + a.addr = addr.String() a.writer = newWriter(func(p []byte, deadline time.Time) (err error) { currentAddr := *(*net.Addr)(atomic.LoadPointer(&peer.addr)) diff --git a/adnl/packet.go b/adnl/packet.go index 64b2a45b..ed93d643 100644 --- a/adnl/packet.go +++ b/adnl/packet.go @@ -184,7 +184,7 @@ func (p *PacketContent) Serialize(buf *bytes.Buffer) (int, error) { binary.LittleEndian.PutUint32(tmp, _PacketContentID) buf.Write(tmp) - tl.ToBytesToBuffer(buf, p.Rand1) + _ = tl.ToBytesToBuffer(buf, p.Rand1) var flags uint32 if p.Seqno != nil { @@ -314,10 +314,10 @@ func (p *PacketContent) Serialize(buf *bytes.Buffer) (int, error) { } if p.Signature != nil { - tl.ToBytesToBuffer(buf, p.Signature) + _ = tl.ToBytesToBuffer(buf, p.Signature) } - tl.ToBytesToBuffer(buf, p.Rand2) + _ = tl.ToBytesToBuffer(buf, p.Rand2) return payloadLen, nil } diff --git a/adnl/rldp/bbr2.go b/adnl/rldp/bbr2.go new file mode 100644 index 00000000..b780d235 --- /dev/null +++ b/adnl/rldp/bbr2.go @@ -0,0 +1,571 @@ +package rldp + +import ( + "fmt" + "math" + "sync/atomic" + "time" +) + +var BBRLogger func(a ...any) = nil + 
+type SendClock struct { + mask uint32 + startedAt int64 + slots []atomic.Uint64 // packed: [seqno:32][t_ms:32] +} + +func NewSendClock(capPow2 int) *SendClock { + if capPow2&(capPow2-1) != 0 { + panic("cap must be power of two") + } + + s := &SendClock{ + startedAt: time.Now().UnixMilli(), + mask: uint32(capPow2 - 1), + slots: make([]atomic.Uint64, capPow2), + } + return s +} + +func pack(seq, ms uint32) uint64 { return (uint64(seq) << 32) | uint64(ms) } +func unpack(v uint64) (seq, ms uint32) { return uint32(v >> 32), uint32(v) } + +func (s *SendClock) OnSend(seq uint32, nowMs int64) { + idx := seq & s.mask + s.slots[idx].Store(pack(seq, uint32(nowMs-s.startedAt))) +} + +func (s *SendClock) SentAt(seq uint32) (ms int64, ok bool) { + idx := seq & s.mask + v := s.slots[idx].Load() + if hi, lo := unpack(v); hi == seq { + return int64(lo) + s.startedAt, true + } + return 0, false +} + +type BBRv2Options struct { + // Time window for bottleneck bandwidth estimation (seconds) + BtlBwWindowSec int + + // Minimum duration of a gain cycle in ProbeBW (ms) + ProbeBwCycleMs int64 + + // Duration of ProbeRTT phase (ms) + ProbeRTTDurationMs int64 + + // MinRTT staleness timeout (ms): enter ProbeRTT if minRTT hasn't been refreshed longer than this + MinRTTExpiryMs int64 + + // Lower and upper bounds for pacing (bytes/sec) + MinRate int64 + MaxRate int64 // 0 = no cap + + // Threshold for "high loss" (fraction) + HighLoss float64 // e.g., 0.02..0.1 + + // Beta factor to shrink inflight_hi when losses are high + Beta float64 // e.g., 0.85 + + // Initial "guessed" RTT if ObserveRTT is unavailable + DefaultRTTMs int64 + + // Minimum ACK window duration (ms) to avoid updating too frequently + MinSampleMs int64 + + Name string +} + +type BBRv2Controller struct { + limiter *TokenBucket + opts BBRv2Options + + // Accumulators for input deltas + _total atomic.Int64 + _recv atomic.Int64 + _samples atomic.Int64 + lastProc atomic.Int64 // unix ms of the last update + + // BBR state + state 
atomic.Int32 // 0=startup, 1=drain, 2=probebw, 3=probertt + cycleStamp atomic.Int64 // start time of the current gain cycle + cycleIndex atomic.Int32 // index within the gain table + fullBW atomic.Int64 // "full bandwidth" detection + fullBWCount atomic.Int32 + + // Filters and estimates + btlbw atomic.Int64 // bytes/sec (max filter) + minRTT atomic.Int64 // ms + lastRTT atomic.Int64 // ms + minRTTAt atomic.Int64 // unix ms when minRTT was last updated + minRTTProvisional atomic.Bool + inflight atomic.Int64 // target inflight (bytes), roughly BtlBw * minRTT + hiInflight atomic.Int64 + loInflight atomic.Int64 + + // Loss accounting for the current window + lossTotal atomic.Int64 + lossLost atomic.Int64 + lastAckTs atomic.Int64 // unix ms marking the start of the ACK window + + // Current pacing rate (bytes/sec) + pacingRate atomic.Int64 + + appLimited atomic.Bool + + dbgLast atomic.Int64 + + lastBtlBwDecay atomic.Int64 + + lastLossRate atomic.Uint64 + lastSampleTot atomic.Int64 + lastSampleLos atomic.Int64 +} + +func NewBBRv2Controller(l *TokenBucket, o BBRv2Options) *BBRv2Controller { + applyBBRDefaults(&o) + now := nowMs() + c := &BBRv2Controller{ + limiter: l, + opts: o, + } + c.state.Store(0) + c.cycleStamp.Store(now) + c.lastProc.Store(now) + c.lastAckTs.Store(now) + c.lastBtlBwDecay.Store(now) + + if o.MinRate > 0 { + c.pacingRate.Store(o.MinRate) + l.SetRate(o.MinRate) + } + + if o.DefaultRTTMs > 0 { + c.minRTT.Store(o.DefaultRTTMs) + c.minRTTAt.Store(now) + } else { + c.minRTT.Store(25) + c.minRTTAt.Store(now) + } + c.lastRTT.Store(c.minRTT.Load()) + c.minRTTProvisional.Store(true) + + start := l.GetRate() + if start <= 0 { + start = max64(o.MinRate, 1024*64) + } + c.btlbw.Store(start) + c.pacingRate.Store(start) + c.inflight.Store(rateToInflight(start, c.minRTT.Load())) + c.hiInflight.Store(c.inflight.Load()) + c.loInflight.Store(0) + + return c +} + +func applyBBRDefaults(o *BBRv2Options) { + if o.BtlBwWindowSec == 0 { + o.BtlBwWindowSec = 10 + } + if 
o.ProbeBwCycleMs == 0 { + o.ProbeBwCycleMs = 200 + } + if o.ProbeRTTDurationMs == 0 { + o.ProbeRTTDurationMs = 150 + } + if o.MinRTTExpiryMs == 0 { + o.MinRTTExpiryMs = 10_000 // 10s + } + if o.MinRate == 0 { + o.MinRate = 32 * 1024 // 32 KiB/s + } + if o.HighLoss == 0 { + o.HighLoss = 0.05 // 5% + } + if o.Beta == 0 { + o.Beta = 0.85 + } + if o.DefaultRTTMs == 0 { + o.DefaultRTTMs = 25 + } + if o.MinSampleMs == 0 { + o.MinSampleMs = 25 + } +} + +func (c *BBRv2Controller) SetAppLimited(v bool) { c.appLimited.Store(v) } + +func (c *BBRv2Controller) ObserveDelta(total, recv int64) { + if total == 0 { + return + } + + c._total.Add(total) + c._recv.Add(recv) + c._samples.Add(1) + c.maybeUpdate() +} + +func (c *BBRv2Controller) ObserveRTT(rttMs int64) { + now := nowMs() + old := c.minRTT.Load() + provisional := c.minRTTProvisional.Load() + + if old == 0 || provisional || rttMs < old { + c.minRTT.Store(rttMs) + c.minRTTAt.Store(now) + c.minRTTProvisional.Store(false) + } else if rttMs <= old+max64(1, old/8) { // <= 12.5% from min + c.minRTTAt.Store(now) + } + c.lastRTT.Store(rttMs) + + if btl := c.btlbw.Load(); btl > 0 { + c.inflight.Store(rateToInflight(btl, c.minRTT.Load())) + } +} + +func (c *BBRv2Controller) maybeUpdate() { + now := nowMs() + + last := c.lastProc.Load() + if last+c.opts.MinSampleMs > now { + return + } + if !c.lastProc.CompareAndSwap(last, now) { + return + } + + prevAckTs := c.lastAckTs.Swap(now) + elapsedMs := now - prevAckTs + if elapsedMs < max64(10, c.opts.MinSampleMs/2) { + return + } + + total := c._total.Swap(0) + acked := c._recv.Swap(0) + c._samples.Store(0) + if total <= 0 { + return + } + + lost := total - acked + if lost < 0 { + lost = 0 + } + + c.lastSampleTot.Store(total) + c.lastSampleLos.Store(lost) + + const minAckBytesForLoss = 2 * 1500 // min ~2 MSS confirmed + if total >= minAckBytesForLoss { + c.lossTotal.Add(total) + c.lossLost.Add(lost) + } + + if acked > 0 { + ackRate := int64(float64(acked) * 1000.0 / float64(elapsedMs)) // 
B/s + c.updateBtlBw(ackRate, now) + } + + c.checkProbeRTT(now, acked) + lossRate := c.updateModelAndRate(now) + + if BBRLogger != nil && now-c.dbgLast.Load() >= 1000 { + c.dbgLast.Store(now) + + var ackRateBps int64 + if elapsedMs > 0 { + ackRateBps = int64(float64(acked) * 1000.0 / float64(elapsedMs)) + } + lossPct := fmt.Sprintf("%.2f%%", lossRate*100.0) + + BBRLogger("[BBR] ", + c.opts.Name, " win elapsed=", elapsedMs, "ms acked="+humanBytes(acked)+" total="+humanBytes(total)+" loss=", lossPct, + "state=", c.state.Load(), "appLimited=", c.appLimited.Load(), "ackRate="+humanBps(ackRateBps)+" pacing="+humanBps(c.pacingRate.Load())+ + " btlbw="+humanBps(c.btlbw.Load())+" minRTT=", c.minRTT.Load(), "ms", + ) + } +} + +func humanBps(bps int64) string { + if bps <= 0 { + return "0 B/s (0 Mbit/s)" + } + miBps := float64(bps) / (1024.0 * 1024.0) // MiB/s + mbps := float64(bps*8) / 1e6 + return fmt.Sprintf("%.2f MB/s (%.2f Mbit/s)", miBps, mbps) +} + +func humanBytes(n int64) string { + const ( + KiB = 1024 + MiB = 1024 * KiB + GiB = 1024 * MiB + ) + switch { + case n >= GiB: + return fmt.Sprintf("%.2f GB", float64(n)/float64(GiB)) + case n >= MiB: + return fmt.Sprintf("%.2f MB", float64(n)/float64(MiB)) + case n >= KiB: + return fmt.Sprintf("%.2f KB", float64(n)/float64(KiB)) + default: + return fmt.Sprintf("%d B", n) + } +} + +func (c *BBRv2Controller) updateBtlBw(sample int64, now int64) { + if sample <= 0 { + return + } + + if !c.appLimited.Load() { + cur := c.btlbw.Load() + if sample > cur { + c.btlbw.Store(sample) + } + + // full bandwidth reached + if cur > 0 { + if float64(sample) < float64(cur)*1.25 { + if c.fullBWCount.Add(1) >= 3 && c.fullBW.Load() == 0 { + c.fullBW.Store(cur) + } + } else { + c.fullBWCount.Store(0) + c.fullBW.Store(0) + c.btlbw.Store(max64(cur, sample)) + } + } + } + + // Soft decay of an overly old max (emulates a time window) + // Every BtlBwWindowSec seconds decrease by 10% if no larger samples arrived + winMs := 
int64(c.opts.BtlBwWindowSec * 1000) + lastDecay := c.lastBtlBwDecay.Load() + if lastDecay == 0 { + lastDecay = now + } + if lastDecay+winMs < now && c.lastBtlBwDecay.CompareAndSwap(lastDecay, now) { + decayed := int64(float64(c.btlbw.Load()) * 0.9) + if decayed < c.opts.MinRate { + decayed = c.opts.MinRate + } + c.btlbw.Store(decayed) + } +} + +func (c *BBRv2Controller) InflightAllowance(currentBytes int64) int64 { + if currentBytes <= 0 { + currentBytes = 0 + } + + hi := c.hiInflight.Load() + if hi <= 0 { + hi = c.inflight.Load() + } + if hi <= 0 { + minRtt := c.minRTT.Load() + if minRtt <= 0 { + minRtt = c.opts.DefaultRTTMs + } + pacing := c.pacingRate.Load() + if pacing <= 0 { + pacing = c.opts.MinRate + } + hi = rateToInflight(pacing, minRtt) + } + + allowance := hi - currentBytes + if allowance <= 0 { + return 0 + } + return allowance +} + +func (c *BBRv2Controller) CurrentMinRTT() int64 { + return c.minRTT.Load() +} + +func (c *BBRv2Controller) CurrentRTT() int64 { + return c.lastRTT.Load() +} + +func (c *BBRv2Controller) checkProbeRTT(now int64, ackedBytes int64) { + if c.state.Load() != 3 && now-c.minRTTAt.Load() > c.opts.MinRTTExpiryMs && + !c.appLimited.Load() && ackedBytes > 0 { + + c.state.Store(3) + c.cycleStamp.Store(now) + } + + if c.state.Load() == 3 && now-c.cycleStamp.Load() >= c.opts.ProbeRTTDurationMs { + c.state.Store(2) + c.cycleStamp.Store(now) + c.cycleIndex.Store(0) + } +} + +func (c *BBRv2Controller) updateModelAndRate(now int64) float64 { + state := c.state.Load() + bw := c.btlbw.Load() + if bw <= 0 { + bw = c.opts.MinRate + } + + // Update inflight target = bw * minRTT + inflight := rateToInflight(bw, c.minRTT.Load()) + if inflight <= 0 { + inflight = 2 * 1500 // at least two MSS-equivalents + } + c.inflight.Store(inflight) + + // Losses in the last window → decide whether to lower inflight_hi + var lossRate float64 + lt := c.lossTotal.Swap(0) + ll := c.lossLost.Swap(0) + if lt > 0 { + lossRate = float64(ll) / float64(lt) + } + + 
c.lastLossRate.Store(math.Float64bits(lossRate)) + + // BBRv2: if loss is high — tighten the upper bound inflight_hi BELOW the model + hi := c.hiInflight.Load() + if hi == 0 { + hi = inflight + } + if lossRate >= c.opts.HighLoss { + // multiplicative decrease like BBRv2 + newHi := int64(float64(hi) * c.opts.Beta) + // allow going below the model to drain the queue, but keep a sane floor + floor := max64(2*1500, inflight/2) // >=2*MSS and not below ~0.5*model + if newHi < floor { + newHi = floor + } + c.hiInflight.Store(newHi) + } else { + // Slowly relax upward + relax := hi + max64(inflight/16, 1500) // +~6% or at least one MSS + c.hiInflight.Store(min64(relax, inflight*4)) + } + + // Choose pacing_gain by state + var pacingGain = 1.0 + switch state { + case 0: // Startup + pacingGain = 2.885 // classic BBR startup + // Transition to Drain once "full bandwidth" is reached + if c.fullBW.Load() > 0 { + c.state.Store(1) + c.cycleStamp.Store(now) + } + case 1: // Drain + pacingGain = 1.0 / 2.885 + // Finish drain relatively quickly + if now-c.cycleStamp.Load() >= 200 { + c.state.Store(2) // ProbeBW + c.cycleStamp.Store(now) + c.cycleIndex.Store(0) + } + case 2: // ProbeBW + // Moderate BBRv2 gain cycle: {1.25, 0.75, 1,1,1,1,1,1} + gains := [...]float64{1.25, 0.75, 1, 1, 1, 1, 1, 1} + idx := int(c.cycleIndex.Load()) + if idx < 0 || idx >= len(gains) { + idx = 0 + c.cycleIndex.Store(0) + } + pacingGain = gains[idx] + // Advance the cycle + if now-c.cycleStamp.Load() >= c.opts.ProbeBwCycleMs { + c.cycleStamp.Store(now) + c.cycleIndex.Store(int32((idx + 1) % len(gains))) + } + case 3: // ProbeRTT + pacingGain = 0.5 // send less to probe RTT + } + + // Map inflight_hi into a rate cap (upper bound) + // targetRate = min(bw * pacingGain, hiInflight / minRTT) + targetByGain := float64(bw) * pacingGain + minRtt := max64(c.minRTT.Load(), 1) + hiBytesPerSec := float64(c.hiInflight.Load()) * 1000.0 / float64(minRtt) + target := min64(int64(targetByGain), int64(hiBytesPerSec)) + + 
prev := c.pacingRate.Load() + + if lossRate >= c.opts.HighLoss { + lossCap := int64(float64(prev) * c.opts.Beta) + if lossCap < c.opts.MinRate { + lossCap = c.opts.MinRate + } + if target > lossCap { + target = lossCap + } + } + + // Lower/upper bounds + if target < c.opts.MinRate { + target = c.opts.MinRate + } + if c.opts.MaxRate > 0 && target > c.opts.MaxRate { + target = c.opts.MaxRate + } + + // Smoothing: limit step changes up/down (except during Startup/ProbeRTT) + maxUp := int64(float64(prev) * 1.5) + maxDown := int64(float64(prev) * 0.7) + if state != 0 && state != 3 { // don't limit in Startup/ProbeRTT + if target > maxUp { + target = maxUp + } + if target < maxDown { + target = maxDown + } + } + + if target <= 0 { + target = c.opts.MinRate + } + + if target != prev { + c.pacingRate.Store(target) + c.limiter.SetRate(target) + } + return lossRate +} + +func (c *BBRv2Controller) LastLossSample() (total, lost int64, rate float64) { + total = c.lastSampleTot.Load() + lost = c.lastSampleLos.Load() + rate = math.Float64frombits(c.lastLossRate.Load()) + return +} + +func rateToInflight(rateBytesPerSec int64, rttMs int64) int64 { + if rateBytesPerSec <= 0 { + return 0 + } + if rttMs <= 0 { + rttMs = 1 + } + return int64(float64(rateBytesPerSec) * float64(rttMs) / 1000.0) +} + +func nowMs() int64 { return time.Now().UnixMilli() } + +func max64(a, b int64) int64 { + if a > b { + return a + } + return b +} +func min64(a, b int64) int64 { + if a < b { + return a + } + return b +} diff --git a/adnl/rldp/bbr2_test.go b/adnl/rldp/bbr2_test.go new file mode 100644 index 00000000..9d26efa6 --- /dev/null +++ b/adnl/rldp/bbr2_test.go @@ -0,0 +1,360 @@ +package rldp + +import ( + "math" + "sync" + "testing" + "time" +) + +func waitUntil(t *testing.T, timeout time.Duration, cond func() bool, msg string) { + t.Helper() + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + if cond() { + return + } + time.Sleep(2 * time.Millisecond) + } + t.Fatalf("timeout: 
%s", msg) +} + +func newBBR(t *testing.T, initRate int64, opts BBRv2Options) (*BBRv2Controller, *TokenBucket) { + t.Helper() + tb := NewTokenBucket(initRate, "test-peer") + return NewBBRv2Controller(tb, opts), tb +} + +func TestBBR_StartupIncreasesRate(t *testing.T) { + opts := BBRv2Options{ + MinRate: 20_000, + MaxRate: 0, + DefaultRTTMs: 20, + MinSampleMs: 10, + BtlBwWindowSec: 2, + ProbeBwCycleMs: 50, + ProbeRTTDurationMs: 40, + MinRTTExpiryMs: 5_000, + HighLoss: 0.05, + Beta: 0.85, + } + bbr, tb := newBBR(t, opts.MinRate, opts) + + for i := 0; i < 60; i++ { + bbr.ObserveDelta(100_000, 100_000) + time.Sleep(12 * time.Millisecond) + } + + waitUntil(t, 2*time.Second, func() bool { + return bbr.pacingRate.Load() > opts.MinRate + }, "pacingRate should increase in Startup") + + if tb.GetRate() != bbr.pacingRate.Load() { + t.Fatalf("limiter rate mismatch: tb=%d bbr=%d", tb.GetRate(), bbr.pacingRate.Load()) + } +} + +func TestBBR_HighLossReducesRate(t *testing.T) { + opts := BBRv2Options{ + MinRate: 50_000, + DefaultRTTMs: 25, + MinSampleMs: 10, + BtlBwWindowSec: 2, + ProbeBwCycleMs: 50, + ProbeRTTDurationMs: 50, + MinRTTExpiryMs: 5_000, + HighLoss: 0.05, + Beta: 0.85, + } + bbr, _ := newBBR(t, opts.MinRate, opts) + + for i := 0; i < 12; i++ { + bbr.ObserveDelta(200_000, 200_000) + time.Sleep(12 * time.Millisecond) + } + r1 := bbr.pacingRate.Load() + + for i := 0; i < 8; i++ { + bbr.ObserveDelta(100_000, 80_000) + time.Sleep(12 * time.Millisecond) + } + r2 := bbr.pacingRate.Load() + if r2 >= r1 { + t.Fatalf("expected rate drop on high loss: before=%d after=%d", r1, r2) + } +} + +func TestBBR_ProbeRTT_EnterAndExit(t *testing.T) { + opts := BBRv2Options{ + MinRate: 30_000, + DefaultRTTMs: 10, + MinSampleMs: 10, + BtlBwWindowSec: 2, + ProbeBwCycleMs: 40, + ProbeRTTDurationMs: 30, + MinRTTExpiryMs: 60, + HighLoss: 0.1, + Beta: 0.9, + } + bbr, _ := newBBR(t, opts.MinRate, opts) + + // немного трафика + for i := 0; i < 5; i++ { + bbr.ObserveDelta(50_000, 50_000) + 
time.Sleep(12 * time.Millisecond) + } + + time.Sleep(70 * time.Millisecond) + for i := 0; i < 10 && bbr.state.Load() != 3; i++ { + bbr.ObserveDelta(1_000, 1_000) + time.Sleep(12 * time.Millisecond) + } + if bbr.state.Load() != 3 { + t.Fatalf("should enter ProbeRTT") + } + + for i := 0; i < 20 && bbr.state.Load() != 2; i++ { + bbr.ObserveDelta(1_000, 1_000) + time.Sleep(12 * time.Millisecond) + } + if bbr.state.Load() != 2 { + t.Fatalf("should exit to ProbeBW") + } +} + +func TestBBR_RespectsMinMaxRate(t *testing.T) { + opts := BBRv2Options{ + MinRate: 10_000, + MaxRate: 15_000, + DefaultRTTMs: 10, + MinSampleMs: 10, + ProbeBwCycleMs: 40, + ProbeRTTDurationMs: 40, + MinRTTExpiryMs: 5_000, + HighLoss: 0.05, + Beta: 0.85, + } + bbr, _ := newBBR(t, opts.MinRate, opts) + + for i := 0; i < 25; i++ { + bbr.ObserveDelta(1_000_000, 1_000_000) + time.Sleep(12 * time.Millisecond) + } + if got := bbr.pacingRate.Load(); got > opts.MaxRate { + t.Fatalf("rate exceeded MaxRate: got=%d max=%d", got, opts.MaxRate) + } + + for i := 0; i < 5; i++ { + bbr.ObserveDelta(100_000, 20_000) // 80% loss + time.Sleep(12 * time.Millisecond) + } + if got := bbr.pacingRate.Load(); got < opts.MinRate { + t.Fatalf("rate fell below MinRate: got=%d min=%d", got, opts.MinRate) + } +} + +func TestBBR_SmoothingBounds(t *testing.T) { + opts := BBRv2Options{ + MinRate: 80_000, + DefaultRTTMs: 20, + MinSampleMs: 10, + BtlBwWindowSec: 2, + ProbeBwCycleMs: 50, + ProbeRTTDurationMs: 50, + MinRTTExpiryMs: 5_000, + HighLoss: 0.05, + Beta: 0.85, + } + bbr, _ := newBBR(t, opts.MinRate, opts) + + for i := 0; i < 10; i++ { + bbr.ObserveDelta(500_000, 500_000) + time.Sleep(12 * time.Millisecond) + } + prev := bbr.pacingRate.Load() + + bbr.ObserveDelta(10_000_000, 10_000_000) + time.Sleep(14 * time.Millisecond) + now := bbr.pacingRate.Load() + if now > int64(float64(prev)*1.55) { + t.Fatalf("up-smoothing failed: prev=%d now=%d", prev, now) + } + + bbr.ObserveDelta(1_000_000, 100_000) // 90% loss + time.Sleep(14 * 
time.Millisecond) + after := bbr.pacingRate.Load() + if after < int64(float64(now)*0.65) { + t.Fatalf("down-smoothing failed: now=%d after=%d", now, after) + } +} + +func TestBBR_BtlBwDecay(t *testing.T) { + opts := BBRv2Options{ + MinRate: 10_000, + DefaultRTTMs: 20, + MinSampleMs: 10, + BtlBwWindowSec: 1, + ProbeBwCycleMs: 100, + ProbeRTTDurationMs: 100, + MinRTTExpiryMs: 10_000, + HighLoss: 0.2, + Beta: 0.85, + } + bbr, _ := newBBR(t, opts.MinRate, opts) + + for i := 0; i < 6; i++ { + bbr.ObserveDelta(200_000, 200_000) + time.Sleep(12 * time.Millisecond) + } + peak := bbr.btlbw.Load() + if peak <= opts.MinRate { + t.Fatalf("unexpected peak btlbw: %d", peak) + } + + winMs := int64(opts.BtlBwWindowSec * 1000) + bbr.lastBtlBwDecay.Store(time.Now().UnixMilli() - winMs - 10) + bbr.ObserveDelta(1, 1) + time.Sleep(12 * time.Millisecond) + + decayed := bbr.btlbw.Load() + if !(decayed < peak && decayed >= opts.MinRate) { + t.Fatalf("expected btlbw decay: before=%d after=%d (min=%d)", peak, decayed, opts.MinRate) + } +} + +func TestBBR_LossSampleTracked(t *testing.T) { + opts := BBRv2Options{ + MinRate: 60_000, + DefaultRTTMs: 20, + MinSampleMs: 10, + BtlBwWindowSec: 2, + ProbeBwCycleMs: 40, + ProbeRTTDurationMs: 40, + MinRTTExpiryMs: 5_000, + HighLoss: 0.05, + Beta: 0.85, + } + bbr, _ := newBBR(t, opts.MinRate, opts) + + time.Sleep(12 * time.Millisecond) + bbr.ObserveDelta(400_000, 200_000) + + total, lost, rate := bbr.LastLossSample() + if total <= 0 || lost <= 0 { + t.Fatalf("expected non-zero loss sample, got total=%d lost=%d", total, lost) + } + + expected := float64(lost) / float64(total) + if math.Abs(rate-expected) > 1e-3 { + t.Fatalf("loss rate mismatch: want %.4f got %.4f (total=%d lost=%d)", expected, rate, total, lost) + } +} + +func TestBBR_InflightAllowance(t *testing.T) { + opts := BBRv2Options{ + MinRate: 40_000, + DefaultRTTMs: 20, + MinSampleMs: 10, + } + bbr, _ := newBBR(t, opts.MinRate, opts) + + bbr.hiInflight.Store(60_000) + if got := 
bbr.InflightAllowance(30_000); got != 30_000 { + t.Fatalf("unexpected allowance: %d", got) + } + + if got := bbr.InflightAllowance(80_000); got != 0 { + t.Fatalf("allowance should clamp at zero, got %d", got) + } + + bbr.hiInflight.Store(0) + bbr.inflight.Store(45_000) + if got := bbr.InflightAllowance(5_000); got != 40_000 { + t.Fatalf("fallback allowance mismatch: %d", got) + } +} + +func approxI64(a, b, tol int64) bool { + d := a - b + if d < 0 { + d = -d + } + return d <= tol +} + +func TestSendClock_OnSendAndSentAt(t *testing.T) { + sc := NewSendClock(1024) // power of two + seq := uint32(42) + + now := time.Now().UnixMilli() + sc.OnSend(seq, now) + + got, ok := sc.SentAt(seq) + if !ok { + t.Fatalf("SentAt(%d) ok=false, want true", seq) + } + // допускаем небольшую разницу (квант времени) + if !approxI64(got, now, 3) { + t.Fatalf("SentAt ms mismatch: got=%d want~=%d", got, now) + } +} + +func TestSendClock_CollisionOverwritesOld(t *testing.T) { + sc := NewSendClock(8) // mask=7 + base := time.Now().UnixMilli() + + seq1 := uint32(10) // 10 & 7 = 2 + seq2 := uint32(18) // 18 & 7 = 2 + + sc.OnSend(seq1, base+1) + sc.OnSend(seq2, base+2) + + if _, ok := sc.SentAt(seq1); ok { + t.Fatalf("expected collision overwrite for seq1=%d", seq1) + } + got2, ok2 := sc.SentAt(seq2) + if !ok2 || got2 != base+2 { + t.Fatalf("seq2 not found or bad ts: ok=%v got=%d want=%d", ok2, got2, base+2) + } +} + +func TestSendClock_RecentWindowVisibleAfterWrap(t *testing.T) { + capPow2 := 16 + sc := NewSendClock(capPow2) + start := time.Now().UnixMilli() + + for i := 0; i < 1000; i++ { + sc.OnSend(uint32(i), start+int64(i)) + } + for i := 1000 - capPow2; i < 1000; i++ { + if _, ok := sc.SentAt(uint32(i)); !ok { + t.Fatalf("recent seq=%d not found", i) + } + } +} + +func TestSendClock_ConcurrentRaces(t *testing.T) { + sc := NewSendClock(4096) + var wg sync.WaitGroup + + wg.Add(1) + go func() { + defer wg.Done() + base := time.Now().UnixMilli() + for i := uint32(1); i < 50000; i++ { + 
sc.OnSend(i, base+int64(i%5000)) + } + }() + + for r := 0; r < 4; r++ { + wg.Add(1) + go func() { + defer wg.Done() + for i := uint32(1); i < 50000; i++ { + sc.SentAt(i) + } + }() + } + + wg.Wait() +} diff --git a/adnl/rldp/bucket.go b/adnl/rldp/bucket.go index f3e00cdd..61a7b9b0 100644 --- a/adnl/rldp/bucket.go +++ b/adnl/rldp/bucket.go @@ -5,36 +5,57 @@ import ( "time" ) +// TokenBucket bytes/sec type TokenBucket struct { ratePerSec int64 capacity int64 - tokens int64 - lastRefill int64 + + lastRefill int64 // UnixMicro peerName string } -func NewTokenBucket(rate int64, peerName string) *TokenBucket { +// NewTokenBucket create bucket with bytes/sec. +func NewTokenBucket(bps int64, peerName string) *TokenBucket { + if bps < 1 { + bps = 1 + } + x := bps * 1000 // burst 1 sec + return &TokenBucket{ - ratePerSec: rate * 1000, - capacity: rate * 1000, - tokens: rate * 1000, + ratePerSec: x, + capacity: x, + tokens: x, lastRefill: time.Now().UnixMicro(), peerName: peerName, } } -func (tb *TokenBucket) SetRate(pps int64) { - if pps < 128 { - pps = 128 - } else if pps > 5000000 { - pps = 5000000 +func (tb *TokenBucket) SetCapacityBytes(burstBytes int64) { + if burstBytes < 0 { + burstBytes = 0 + } + atomic.StoreInt64(&tb.capacity, burstBytes*1000) +} + +func (tb *TokenBucket) SetRate(bps int64) { + if bps < 8<<10 { // 8KB/s + bps = 8 << 10 + } else if bps > 500<<20 { // 500 MB/s + bps = 500 << 20 } + atomic.StoreInt64(&tb.ratePerSec, bps*1000) - atomic.StoreInt64(&tb.ratePerSec, pps*1000) - atomic.StoreInt64(&tb.capacity, pps*1000) - Logger("[RLDP] Peer rate updated:", tb.peerName, pps) + curCap := atomic.LoadInt64(&tb.capacity) + curRate := atomic.LoadInt64(&tb.ratePerSec) + + // if cap ~= old rate, use new + if abs64(curCap-curRate) < curRate/64 { // ~1.5% + atomic.StoreInt64(&tb.capacity, curRate) + } + + Logger("[RLDP] Peer pacing updated (Bps):", tb.peerName, bps) } func (tb *TokenBucket) GetRate() int64 { @@ -45,7 +66,12 @@ func (tb *TokenBucket) GetTokensLeft() 
int64 { return atomic.LoadInt64(&tb.tokens) / 1000 } -func (tb *TokenBucket) TryConsume() bool { +func (tb *TokenBucket) ConsumeUpTo(maxBytes int) int { + if maxBytes <= 0 { + return 0 + } + req := int64(maxBytes) + for { now := time.Now().UnixMicro() last := atomic.LoadInt64(&tb.lastRefill) @@ -53,21 +79,58 @@ func (tb *TokenBucket) TryConsume() bool { if elapsed > 0 { add := (elapsed * atomic.LoadInt64(&tb.ratePerSec)) / 1_000_000 - newTokens := atomic.LoadInt64(&tb.tokens) + add - if capacity := atomic.LoadInt64(&tb.capacity); newTokens > capacity { - newTokens = capacity - } - if atomic.CompareAndSwapInt64(&tb.lastRefill, last, now) { - atomic.StoreInt64(&tb.tokens, newTokens) + if add > 0 && atomic.CompareAndSwapInt64(&tb.lastRefill, last, now) { + for { + curr := atomic.LoadInt64(&tb.tokens) + newTokens := curr + add + capacity := atomic.LoadInt64(&tb.capacity) + if newTokens > capacity { + newTokens = capacity + } + + if atomic.CompareAndSwapInt64(&tb.tokens, curr, newTokens) { + break + } + } } } - if currTokens := atomic.LoadInt64(&tb.tokens); currTokens >= 1000 { // micro-tokens - if !atomic.CompareAndSwapInt64(&tb.tokens, currTokens, currTokens-1000) { - continue - } - return true + currTokens := atomic.LoadInt64(&tb.tokens) + availableBytes := currTokens / 1000 + if availableBytes <= 0 { + return 0 + } + + toConsume := req + if availableBytes < toConsume { + toConsume = availableBytes } - return false + + micro := toConsume * 1000 + if atomic.CompareAndSwapInt64(&tb.tokens, currTokens, currTokens-micro) { + return int(toConsume) + } + + // race, repeat + } +} + +func (tb *TokenBucket) ConsumePackets(maxPackets, partSize int) int { + if maxPackets <= 0 || partSize <= 0 { + return 0 + } + wantBytes := int64(maxPackets) * int64(partSize) + gotBytes := tb.ConsumeUpTo(int(wantBytes)) + return gotBytes / partSize +} + +func (tb *TokenBucket) TryConsumeBytes(n int) bool { + return tb.ConsumeUpTo(n) == n +} + +func abs64(x int64) int64 { + if x < 0 { + return 
-x } + return x } diff --git a/adnl/rldp/bucket_test.go b/adnl/rldp/bucket_test.go new file mode 100644 index 00000000..41da5eb8 --- /dev/null +++ b/adnl/rldp/bucket_test.go @@ -0,0 +1,150 @@ +package rldp + +import ( + "sync" + "sync/atomic" + "testing" + "time" +) + +func wait(ms int) { time.Sleep(time.Duration(ms) * time.Millisecond) } + +func TestTokenBucket_Init(t *testing.T) { + tb := NewTokenBucket(100_000, "peer") // 100 KB/s + if got := tb.GetRate(); got != 100_000 { + t.Fatalf("GetRate=%d want=100000", got) + } + + if got := tb.GetTokensLeft(); got != 100_000 { + t.Fatalf("GetTokensLeft init=%d want=100000", got) + } +} + +func TestTokenBucket_Consume_And_Exhaust(t *testing.T) { + tb := NewTokenBucket(10_000, "peer") // 10 KB/s + if n := tb.ConsumeUpTo(4_000); n != 4_000 { + t.Fatalf("ConsumeUpTo 4k -> %d want 4000", n) + } + if left := tb.GetTokensLeft(); left != 6_000 { + t.Fatalf("left=%d want 6000", left) + } + + if n := tb.ConsumeUpTo(7_000); n != 6_000 { + t.Fatalf("ConsumeUpTo 7k -> %d want 6000 (cap by available)", n) + } + + if left := tb.GetTokensLeft(); left != 0 { + t.Fatalf("left after exhaust=%d want 0", left) + } + + if n := tb.ConsumeUpTo(1_000); n != 0 { + t.Fatalf("ConsumeUpTo when empty -> %d want 0", n) + } +} + +func TestTokenBucket_RefillOverTime(t *testing.T) { + tb := NewTokenBucket(10_000, "peer") // 10 KB/s + _ = tb.ConsumeUpTo(10_000) + if left := tb.GetTokensLeft(); left != 0 { + t.Fatalf("left=%d want 0", left) + } + + wait(120) + + n := tb.ConsumeUpTo(10_000) + if n < 800 || n > 2_000 { + t.Fatalf("refill bytes=%d want around 1200 (800..2000)", n) + } +} + +func TestTokenBucket_SetRate_DownAndUp(t *testing.T) { + tb := NewTokenBucket(40_000, "peer") // 40 KB/s + _ = tb.ConsumeUpTo(40_000) + tb.SetRate(10_000) + wait(110) + n := tb.ConsumeUpTo(10_000) + if n < 700 || n > 2_000 { + t.Fatalf("after downrate refill=%d want ~1100 (700..2000)", n) + } + + tb.SetRate(200_000) + _ = tb.ConsumeUpTo(200_000) + wait(100) // ~20_000 B + 
n2 := tb.ConsumeUpTo(1_000_000) + if n2 < 12_000 || n2 > 40_000 { + t.Fatalf("after uprate refill=%d want ~20000 (12000..40000)", n2) + } +} + +func TestTokenBucket_SetCapacityBytes_Burst(t *testing.T) { + tb := NewTokenBucket(50_000, "peer") + tb.SetCapacityBytes(10_000) + + _ = tb.ConsumeUpTo(1_000_000) + wait(250) + got := tb.ConsumeUpTo(1_000_000) + if got < 9_000 || got > 10_000 { + t.Fatalf("burst-capped consume=%d want ~10k (9000..10000)", got) + } +} + +func TestTokenBucket_ConsumePackets(t *testing.T) { + tb := NewTokenBucket(12_000, "peer") // 12 kB/s + gotPk := tb.ConsumePackets(100, 1_200) + if gotPk != 10 { // 12k / 1.2k = 10 + t.Fatalf("ConsumePackets first=%d want 10", gotPk) + } + + wait(105) + + gotPk2 := tb.ConsumePackets(100, 1_200) + if gotPk2 != 1 { + t.Fatalf("ConsumePackets after refill=%d want 1", gotPk2) + } +} + +func TestTokenBucket_TryConsumeBytes(t *testing.T) { + tb := NewTokenBucket(5_000, "peer") + if ok := tb.TryConsumeBytes(512); !ok { + t.Fatalf("TryConsumeBytes(512) = false, want true") + } + + left := tb.GetTokensLeft() + if left < 4_400 || left > 4_500 { + t.Fatalf("left ~ 4488.., got %d", left) + } + + if ok := tb.TryConsumeBytes(10_000); ok { + t.Fatalf("TryConsumeBytes big should be false") + } +} + +func TestTokenBucket_ParallelConsume_NoOveruse(t *testing.T) { + tb := NewTokenBucket(100_000, "peer") // 100k B/s + testDur := 100 * time.Millisecond + start := time.Now() + var consumed atomic.Int64 + + wg := sync.WaitGroup{} + workers := 8 + wg.Add(workers) + for w := 0; w < workers; w++ { + go func() { + defer wg.Done() + for time.Since(start) < testDur { + n := tb.ConsumeUpTo(1_500) + if n > 0 { + consumed.Add(int64(n)) + } else { + time.Sleep(200 * time.Microsecond) + } + } + }() + } + wg.Wait() + + got := consumed.Load() + if got > 120_000 { + t.Fatalf("parallel consumed=%d exceeds expected ~<=120000", got) + } +} diff --git a/adnl/rldp/client.go b/adnl/rldp/client.go index 9a077c86..c88edfa5 100644 --- 
a/adnl/rldp/client.go +++ b/adnl/rldp/client.go @@ -7,6 +7,8 @@ import ( "encoding/hex" "errors" "fmt" + "github.com/xssnick/tonutils-go/adnl/rldp/roundrobin" + "math" "reflect" "sort" "sync" @@ -31,14 +33,23 @@ type ADNL interface { var Logger = func(a ...any) {} -var PartSize = uint32(1 << 20) +var PartSize = uint32(256 << 10) + +var MultiFECMode = false // TODO: activate after some versions +var RoundRobinFECLimit = 50 * DefaultSymbolSize + +type fecEncoder interface { + GenSymbol(id uint32) []byte +} type activeTransferPart struct { - encoder *raptorq.Encoder + encoder fecEncoder seqno uint32 index uint32 - feq FECRaptorQ + fec FEC + fecSymbolSize uint32 + fecSymbolsCount uint32 lastConfirmRecvProcessed uint32 lastConfirmSeqnoProcessed uint32 @@ -48,6 +59,10 @@ type activeTransferPart struct { nextRecoverDelay int64 fastSeqnoTill uint32 recoveryReady atomic.Int32 + lossEWMA atomic.Uint64 + drrDeficit int64 + + sendClock *SendClock transfer *activeTransfer } @@ -55,10 +70,12 @@ type activeTransferPart struct { type activeTransfer struct { id []byte data []byte + totalSize uint64 timeoutAt int64 - currentPart atomic.Pointer[activeTransferPart] - rldp *RLDP + nextPartIndex uint32 + currentPart atomic.Pointer[activeTransferPart] + rldp *RLDP mx sync.Mutex } @@ -94,15 +111,24 @@ type RLDP struct { packetsSz uint64 rateLimit *TokenBucket - rateCtrl *AdaptiveRateController + rateCtrl *BBRv2Controller lastReport time.Time } +type fecDecoder interface { + AddSymbol(id uint32, data []byte) (bool, error) + Decode() (bool, []byte, error) +} + type decoderStreamPart struct { index uint32 - decoder *raptorq.Decoder + decoder fecDecoder + + fecDataSize uint32 + fecSymbolSize uint32 + fecSymbolsCount uint32 lastCompleteSeqno uint32 maxSeqno uint32 @@ -119,7 +145,8 @@ type decoderStream struct { lastMessageAt time.Time currentPart decoderStreamPart - messages chan *MessagePart + /// messages chan *MessagePart + msgBuf *Queue totalSize uint64 @@ -129,13 +156,15 @@ type 
decoderStream struct { mx sync.Mutex } -var MaxUnexpectedTransferSize uint64 = 1 << 16 // 64 KB -var MaxFECDataSize uint64 = 2 << 20 // 2 MB -var DefaultFECDataSize uint64 = 1 << 20 // 1 MB +var MaxUnexpectedTransferSize uint64 = 64 << 10 // 64 KB +var MaxFECDataSize uint32 = 2 << 20 // 2 MB var DefaultSymbolSize uint32 = 768 const _MTU = 1 << 37 +var MinRateBytesSec = int64(1 << 20) +var MaxRateBytesSec = int64(512 << 20) + func NewClient(a ADNL) *RLDP { r := &RLDP{ adnl: a, @@ -144,21 +173,21 @@ func NewClient(a ADNL) *RLDP { recvStreams: map[string]*decoderStream{}, expectedTransfers: map[string]*expectedTransfer{}, activateRecoverySender: make(chan bool, 1), - rateLimit: NewTokenBucket(10000, a.RemoteAddr()), - } - - r.rateCtrl = NewAdaptiveRateController(r.rateLimit, AdaptiveRateOptions{ - MinRate: 2500, - MaxRate: 0, - EnableSlowStart: true, - SlowStartMultiplier: 2.5, - TargetLoss: 0.05, - HighLoss: 0.25, - Deadband: 0.01, - DecreaseFactor: 0.15, - MildDecreaseFactor: 0.05, - IncreaseFactor: 0.067, - IncreaseOnlyWhenTokensBelow: 0.75, + rateLimit: NewTokenBucket(1<<20, a.RemoteAddr()), + } + + r.rateCtrl = NewBBRv2Controller(r.rateLimit, BBRv2Options{ + Name: r.adnl.RemoteAddr(), + MinRate: MinRateBytesSec, + MaxRate: MaxRateBytesSec, + HighLoss: 0.25, + Beta: 0.9, + DefaultRTTMs: 25, + BtlBwWindowSec: 10, + ProbeBwCycleMs: 200, + ProbeRTTDurationMs: 200, + MinRTTExpiryMs: 20_000, + MinSampleMs: 50, }) a.SetCustomMessageHandler(r.handleMessage) @@ -223,11 +252,6 @@ func (r *RLDP) handleMessage(msg *adnl.MessageCustom) error { switch m := msg.Data.(type) { case MessagePart: - fec, ok := m.FecType.(FECRaptorQ) - if !ok { - return fmt.Errorf("not supported fec type") - } - tm := time.Now() id := string(m.TransferID) @@ -251,13 +275,14 @@ func (r *RLDP) handleMessage(msg *adnl.MessageCustom) error { return fmt.Errorf("too big transfer size %d, max allowed %d", m.TotalSize, maxTransferSize) } - if m.TotalSize < uint64(fec.DataSize) { - return fmt.Errorf("bad rldp 
total size %d, expected at least %d", m.TotalSize, fec.DataSize) + qsz := int(m.FecType.GetSymbolsCount()) + 3 + if qsz > 1024 { + qsz = 1024 } stream = &decoderStream{ lastMessageAt: tm, - messages: make(chan *MessagePart, 256), + msgBuf: NewQueue(qsz), currentPart: decoderStreamPart{ index: 0, }, @@ -274,11 +299,8 @@ func (r *RLDP) handleMessage(msg *adnl.MessageCustom) error { r.mx.Unlock() } - select { - case stream.messages <- &m: - // put message to queue in case it will be locked by other processor - default: - } + // put a message to queue in case it will be locked by another processor + stream.msgBuf.Enqueue(&m) if !stream.mx.TryLock() { return nil @@ -286,212 +308,242 @@ func (r *RLDP) handleMessage(msg *adnl.MessageCustom) error { defer stream.mx.Unlock() for { - var part *MessagePart - select { - case part = <-stream.messages: - default: + part, ok := stream.msgBuf.Dequeue() + if !ok { return nil } - if stream.finishedAt != nil || stream.currentPart.index > part.Part { - if stream.currentPart.lastCompleteAt.Add(10 * time.Millisecond).Before(tm) { // to not send completions too often - var complete tl.Serializable = Complete{ - TransferID: part.TransferID, - Part: part.Part, - } + err := func() error { + if stream.finishedAt != nil || stream.currentPart.index > part.Part { + if stream.currentPart.lastCompleteAt.Add(10 * time.Millisecond).Before(tm) { // to not send completions too often + var complete tl.Serializable = Complete{ + TransferID: part.TransferID, + Part: part.Part, + } - if isV2 { - complete = CompleteV2(complete.(Complete)) - } + if isV2 { + complete = CompleteV2(complete.(Complete)) + } - // got packet for a finished part, let them know that it is completed, again - // TODO: just mark to auto send later? 
- err := r.adnl.SendCustomMessage(context.Background(), complete) - if err != nil { - return fmt.Errorf("failed to send rldp complete message: %w", err) + // got packet for a finished part, let them know that it is completed, again + err := r.adnl.SendCustomMessage(context.Background(), complete) + if err != nil { + return fmt.Errorf("failed to send rldp complete message: %w", err) + } + stream.currentPart.lastCompleteAt = tm } - stream.currentPart.lastCompleteAt = tm + return nil } - return nil - } - if part.Part > stream.currentPart.index { - return fmt.Errorf("received out of order part %d, expected %d", part.Part, stream.currentPart.index) - } - if part.TotalSize != stream.totalSize { - return fmt.Errorf("received part with bad total size %d, expected %d", part.TotalSize, stream.totalSize) - } - - if stream.currentPart.decoder == nil { - fec, ok := part.FecType.(FECRaptorQ) - if !ok { - return fmt.Errorf("not supported fec type in part: %d", part.Part) + if part.Part > stream.currentPart.index { + return fmt.Errorf("received out of order part %d, expected %d", part.Part, stream.currentPart.index) } - - if uint64(fec.DataSize) > stream.totalSize || fec.DataSize > uint32(MaxFECDataSize) || - fec.SymbolSize == 0 || fec.SymbolsCount == 0 { - return fmt.Errorf("invalid fec") + if part.TotalSize != stream.totalSize { + return fmt.Errorf("received part with bad total size %d, expected %d", part.TotalSize, stream.totalSize) } - dec, err := raptorq.NewRaptorQ(fec.SymbolSize).CreateDecoder(fec.DataSize) - if err != nil { - return fmt.Errorf("failed to init raptorq decoder: %w", err) - } - stream.currentPart.decoder = dec - Logger("[ID]", hex.EncodeToString(part.TransferID), "[RLDP] created decoder for part:", part.Part, "data size:", fec.DataSize, "symbol size:", fec.SymbolSize, "symbols:", fec.SymbolsCount) - } + if stream.currentPart.decoder == nil { + var decoderType uint32 + switch m.FecType.(type) { + case FECRaptorQ: + decoderType = 0 + case FECRoundRobin: + 
decoderType = 1 + default: + return fmt.Errorf("not supported fec type") + } - canTryDecode, err := stream.currentPart.decoder.AddSymbol(part.Seqno, part.Data) - if err != nil { - return fmt.Errorf("failed to add raptorq symbol %d: %w", part.Seqno, err) - } + if m.TotalSize < uint64(m.FecType.GetDataSize()) { + return fmt.Errorf("bad rldp total size %d, expected at least %d", m.TotalSize, m.FecType.GetDataSize()) + } - stream.lastMessageAt = tm - stream.currentPart.receivedNum++ + if uint64(m.FecType.GetDataSize()) > stream.totalSize || m.FecType.GetDataSize() > MaxFECDataSize || + m.FecType.GetSymbolSize() == 0 || m.FecType.GetSymbolsCount() == 0 { + return fmt.Errorf("invalid fec") + } - if canTryDecode { - tmd := time.Now() - decoded, data, err := stream.currentPart.decoder.Decode() - if err != nil { - return fmt.Errorf("failed to decode raptorq packet: %w", err) + var err error + var dec fecDecoder + if decoderType == 0 { + dec, err = raptorq.NewRaptorQ(m.FecType.GetSymbolSize()).CreateDecoder(m.FecType.GetDataSize()) + if err != nil { + return fmt.Errorf("failed to init raptorq decoder: %w", err) + } + } else { + dec, err = roundrobin.NewDecoder(m.FecType.GetSymbolSize(), m.FecType.GetDataSize()) + if err != nil { + return fmt.Errorf("failed to init round robin decoder: %w", err) + } + } + + stream.currentPart.fecDataSize = m.FecType.GetDataSize() + stream.currentPart.fecSymbolSize = m.FecType.GetSymbolSize() + stream.currentPart.fecSymbolsCount = m.FecType.GetSymbolsCount() + stream.currentPart.decoder = dec + + Logger("[ID]", hex.EncodeToString(part.TransferID), "[RLDP] created decoder for part:", part.Part, "data size:", stream.currentPart.fecDataSize, "symbol size:", stream.currentPart.fecSymbolSize, "symbols:", stream.currentPart.fecSymbolsCount) } - // it may not be decoded due to an unsolvable math system, it means we need more symbols - if decoded { - Logger("[RLDP] v2:", isV2, "part", part.Part, "decoded on seqno", part.Seqno, "symbols:", 
fec.SymbolsCount, "decode took", time.Since(tmd).String()) + canTryDecode, err := stream.currentPart.decoder.AddSymbol(part.Seqno, part.Data) + if err != nil { + return fmt.Errorf("failed to add raptorq symbol %d: %w", part.Seqno, err) + } - stream.currentPart = decoderStreamPart{ - index: stream.currentPart.index + 1, - } + stream.lastMessageAt = tm + stream.currentPart.receivedNum++ - if len(data) > 0 { - stream.parts = append(stream.parts, data) - stream.partsSize += uint64(len(data)) + if canTryDecode { + tmd := time.Now() + decoded, data, err := stream.currentPart.decoder.Decode() + if err != nil { + return fmt.Errorf("failed to decode raptorq packet: %w", err) } - var complete tl.Serializable = Complete{ - TransferID: part.TransferID, - Part: part.Part, - } + // it may not be decoded due to an unsolvable math system, it means we need more symbols + if decoded { + Logger("[RLDP] v2:", isV2, "part", part.Part, "decoded on seqno", part.Seqno, "symbols:", stream.currentPart.fecSymbolsCount, "decode took", time.Since(tmd).String()) - if isV2 { - complete = CompleteV2(complete.(Complete)) - } - _ = r.adnl.SendCustomMessage(context.Background(), complete) - - if stream.partsSize >= stream.totalSize { - stream.finishedAt = &tmd - stream.currentPart.decoder = nil - - r.mx.Lock() - for sID, s := range r.recvStreams { - // remove streams that was finished more than 15 sec ago or when it was no messages for more than 30 seconds. 
- if s.lastMessageAt.Add(30*time.Second).Before(tm) || - (s.finishedAt != nil && s.finishedAt.Add(15*time.Second).Before(tm)) { - delete(r.recvStreams, sID) - } + stream.currentPart = decoderStreamPart{ + index: stream.currentPart.index + 1, } - r.mx.Unlock() - if stream.partsSize > stream.totalSize { - return fmt.Errorf("received more data than expected, expected %d, got %d", stream.totalSize, stream.partsSize) - } - buf := make([]byte, stream.totalSize) - off := 0 - for _, p := range stream.parts { - off += copy(buf[off:], p) + if len(data) > 0 { + stream.parts = append(stream.parts, data) + stream.partsSize += uint64(len(data)) } - stream.parts = nil - stream.partsSize = 0 - var res any - if _, err = tl.Parse(&res, buf, true); err != nil { - return fmt.Errorf("failed to parse custom message: %w", err) + var complete tl.Serializable = Complete{ + TransferID: part.TransferID, + Part: part.Part, } - Logger("[RLDP] stream finished and parsed, processing transfer data", hex.EncodeToString(part.TransferID)) + // drop unprocessed messages related to this part + stream.msgBuf.Drain() - switch rVal := res.(type) { - case Query: - handler := r.onQuery - if handler != nil { - transferId := make([]byte, 32) - copy(transferId, part.TransferID) + if isV2 { + complete = CompleteV2(complete.(Complete)) + } + _ = r.adnl.SendCustomMessage(context.Background(), complete) - if err = handler(transferId, &rVal); err != nil { - Logger("failed to handle query: ", err) - } - } - case Answer: - qid := string(rVal.ID) + if stream.partsSize >= stream.totalSize { + stream.finishedAt = &tmd + stream.currentPart.decoder = nil r.mx.Lock() - req := r.activeRequests[qid] - if req != nil { - delete(r.activeRequests, qid) - delete(r.expectedTransfers, id) + for sID, s := range r.recvStreams { + // remove streams that was finished more than 15 sec ago or when it was no messages for more than 30 seconds. 
+ if s.lastMessageAt.Add(30*time.Second).Before(tm) || + (s.finishedAt != nil && s.finishedAt.Add(15*time.Second).Before(tm)) { + delete(r.recvStreams, sID) + } } r.mx.Unlock() - if req != nil { - queryId := make([]byte, 32) - copy(queryId, rVal.ID) + if stream.partsSize > stream.totalSize { + return fmt.Errorf("received more data than expected, expected %d, got %d", stream.totalSize, stream.partsSize) + } + buf := make([]byte, stream.totalSize) + off := 0 + for _, p := range stream.parts { + off += copy(buf[off:], p) + } + stream.parts = nil + stream.partsSize = 0 + + var res any + if _, err = tl.Parse(&res, buf, true); err != nil { + return fmt.Errorf("failed to parse custom message: %w", err) + } - // if channel is full we sacrifice processing speed, responses better - req.result <- AsyncQueryResult{ - QueryID: queryId, - Result: rVal.Data, + Logger("[RLDP] stream finished and parsed, processing transfer data", hex.EncodeToString(part.TransferID)) + + switch rVal := res.(type) { + case Query: + handler := r.onQuery + if handler != nil { + transferId := make([]byte, 32) + copy(transferId, part.TransferID) + + if err = handler(transferId, &rVal); err != nil { + Logger("failed to handle query: ", err) + } + } + case Answer: + qid := string(rVal.ID) + + r.mx.Lock() + req := r.activeRequests[qid] + if req != nil { + delete(r.activeRequests, qid) + delete(r.expectedTransfers, id) } + r.mx.Unlock() + + if req != nil { + queryId := make([]byte, 32) + copy(queryId, rVal.ID) + + // if a channel is full, we sacrifice processing speed, responses better + req.result <- AsyncQueryResult{ + QueryID: queryId, + Result: rVal.Data, + } + } + default: + Logger("[RLDP] skipping unwanted rldp message of type", reflect.TypeOf(res).String()) } - default: - Logger("[RLDP] skipping unwanted rldp message of type", reflect.TypeOf(res).String()) } + return nil + } else { + Logger("[RLDP] part ", part.Part, "decode attempt failure on seqno", part.Seqno, "symbols:", 
stream.currentPart.fecSymbolsCount, "decode took", time.Since(tmd).String()) } - return nil - } else { - Logger("[RLDP] part ", part.Part, "decode attempt failure on seqno", part.Seqno, "symbols:", fec.SymbolsCount, "decode took", time.Since(tmd).String()) } - } - if part.Seqno > stream.currentPart.maxSeqno { - diff := part.Seqno - stream.currentPart.maxSeqno - if diff >= 32 { - stream.currentPart.receivedMask = 0 - } else { - stream.currentPart.receivedMask <<= diff + if part.Seqno > stream.currentPart.maxSeqno { + diff := part.Seqno - stream.currentPart.maxSeqno + if diff >= 32 { + stream.currentPart.receivedMask = 0 + } else { + stream.currentPart.receivedMask <<= diff + } + stream.currentPart.maxSeqno = part.Seqno } - stream.currentPart.maxSeqno = part.Seqno - } - if offset := stream.currentPart.maxSeqno - part.Seqno; offset < 32 { - stream.currentPart.receivedMask |= 1 << offset - } + if offset := stream.currentPart.maxSeqno - part.Seqno; offset < 32 { + stream.currentPart.receivedMask |= 1 << offset + } - // send confirm for each 10 packets or after 20 ms - if stream.currentPart.receivedNum-stream.currentPart.receivedNumConfirmed >= 10 || - stream.currentPart.lastConfirmAt.Add(20*time.Millisecond).Before(tm) { - var confirm tl.Serializable - if isV2 { - confirm = ConfirmV2{ - TransferID: part.TransferID, - Part: part.Part, - MaxSeqno: stream.currentPart.maxSeqno, - ReceivedMask: stream.currentPart.receivedMask, - ReceivedCount: stream.currentPart.receivedNum, + // send confirm for each 10 packets or after 20 ms + if stream.currentPart.receivedNum-stream.currentPart.receivedNumConfirmed >= 10 || + stream.currentPart.lastConfirmAt.Add(20*time.Millisecond).Before(tm) { + var confirm tl.Serializable + if isV2 { + confirm = ConfirmV2{ + TransferID: part.TransferID, + Part: part.Part, + MaxSeqno: stream.currentPart.maxSeqno, + ReceivedMask: stream.currentPart.receivedMask, + ReceivedCount: stream.currentPart.receivedNum, + } + } else { + confirm = Confirm{ + 
TransferID: part.TransferID, + Part: part.Part, + Seqno: stream.currentPart.maxSeqno, + } } - } else { - confirm = Confirm{ - TransferID: part.TransferID, - Part: part.Part, - Seqno: stream.currentPart.maxSeqno, + // we don't care in case of error, not so critical + err = r.adnl.SendCustomMessage(context.Background(), confirm) + if err == nil { + stream.currentPart.receivedNumConfirmed = stream.currentPart.receivedNum + stream.currentPart.lastConfirmAt = tm } } - // we don't care in case of error, not so critical - err = r.adnl.SendCustomMessage(context.Background(), confirm) - if err == nil { - stream.currentPart.receivedNumConfirmed = stream.currentPart.receivedNum - stream.currentPart.lastConfirmAt = tm - } + + return nil + }() + if err != nil { + Logger("[RLDP] transfer", hex.EncodeToString(part.TransferID), "process msg part:", part.Part, "error:", err.Error()) } } case Complete: // receiver has fully received transfer part, send new part or close our stream if done @@ -544,23 +596,98 @@ func (r *RLDP) handleMessage(msg *adnl.MessageCustom) error { } part := t.getCurrentPart() - if part == nil { + if part == nil || part.index != m.Part { break } - lastProc := atomic.LoadUint32(&part.lastConfirmSeqnoProcessed) - if isV2 && lastProc+32 <= m.MaxSeqno && - atomic.CompareAndSwapUint32(&part.lastConfirmSeqnoProcessed, lastProc, m.MaxSeqno) { + if isV2 { + for { + prevSeq := atomic.LoadUint32(&part.lastConfirmSeqnoProcessed) + prevRecv := atomic.LoadUint32(&part.lastConfirmRecvProcessed) - total := (m.MaxSeqno - lastProc) + 1 - prevRecv := atomic.SwapUint32(&part.lastConfirmRecvProcessed, m.ReceivedCount) - var recvDelta uint32 - if m.ReceivedCount >= prevRecv { - recvDelta = m.ReceivedCount - prevRecv - } + advancedSeq := m.MaxSeqno > prevSeq + advancedRecv := m.ReceivedCount > prevRecv + if !advancedSeq && !advancedRecv { + break + } + + if advancedSeq && !atomic.CompareAndSwapUint32(&part.lastConfirmSeqnoProcessed, prevSeq, m.MaxSeqno) { + continue + } + + if 
advancedRecv { + atomic.StoreUint32(&part.lastConfirmRecvProcessed, m.ReceivedCount) + } + + var seqDelta int64 + if advancedSeq { + seqDelta = int64(m.MaxSeqno - prevSeq) + } + + var recvDelta int64 + if advancedRecv { + recvDelta = int64(m.ReceivedCount - prevRecv) + } + + if seqDelta < 0 { + seqDelta = 0 + } + if recvDelta < 0 { + recvDelta = 0 + } + + totalDelta := seqDelta + if totalDelta < recvDelta { + totalDelta = recvDelta + } + if totalDelta <= 0 { + break + } + + if tms, ok := part.sendClock.SentAt(m.MaxSeqno); ok { + r.rateCtrl.ObserveRTT(time.Now().UnixMilli() - tms) + } + + r.rateCtrl.ObserveDelta(totalDelta*int64(part.fecSymbolSize), recvDelta*int64(part.fecSymbolSize)) + + loss := 1.0 + if totalDelta > 0 { + loss = 1.0 - float64(recvDelta)/float64(totalDelta) + if loss < 0 { + loss = 0 + } + if loss > 1 { + loss = 1 + } + } + + const alpha = 0.2 + prev := math.Float64frombits(part.lossEWMA.Load()) + if prev == 0 { + prev = loss + } + ew := prev*(1-alpha) + loss*alpha + part.lossEWMA.Store(math.Float64bits(ew)) + + // (3% base + 1.5 * loss), limit 30% + k := float64(part.fecSymbolsCount) + overhead := 0.03 + 1.5*ew + if overhead < 0.01 { + overhead = 0.01 + } + + if overhead > 0.30 { + overhead = 0.30 + } + + target := uint32(k + math.Ceil(k*overhead)) + + cur := atomic.LoadUint32(&part.fastSeqnoTill) + if target > cur { + atomic.StoreUint32(&part.fastSeqnoTill, target) + } - if total > 0 { - r.rateCtrl.ObserveDelta(total, recvDelta) + break } } @@ -585,6 +712,9 @@ func (r *RLDP) recoverySender() { ticker := time.NewTicker(1 * time.Millisecond) defer ticker.Stop() + // round-robin head for fair recovery + var rrHead uint32 + active := false for { select { @@ -657,66 +787,123 @@ func (r *RLDP) recoverySender() { sort.Slice(transfersToProcess, func(i, j int) bool { // recently confirmed transfers are prioritized - return atomic.LoadInt64(&transfersToProcess[i].lastConfirmAt) > atomic.LoadInt64(&transfersToProcess[j].lastConfirmAt) + return 
atomic.LoadInt64(&transfersToProcess[i].lastConfirmAt) > + atomic.LoadInt64(&transfersToProcess[j].lastConfirmAt) }) isV2 := r.useV2.Load() == 1 - loop: - for _, part := range transfersToProcess { - numToResend := 1 - if sc := part.feq.SymbolsCount / 300; sc > 1 { // up to 0.3% per loop - numToResend = int(sc) - } + n := len(transfersToProcess) + if n > 0 { + start := int(rrHead % uint32(n)) + drained := false + lastServedIdx := -1 + + sendLoop: + for i := 0; i < n; i++ { + idx := (start + i) % n + part := transfersToProcess[idx] + + seqno := atomic.LoadUint32(&part.seqno) + + quantum := int64(1) + if sc := part.fecSymbolsCount / 200; sc > 1 { + quantum = int64(sc) + } - seqno := atomic.LoadUint32(&part.seqno) - if seqno < part.fastSeqnoTill { - if diff := int(part.fastSeqnoTill - seqno); diff > numToResend { - numToResend = diff + if seqno < part.fastSeqnoTill { + fastDiff := int64(part.fastSeqnoTill - seqno) + if fastDiff > quantum { + quantum = fastDiff + } } - } - consumed := false - for i := 0; i < numToResend; i++ { - if !r.rateLimit.TryConsume() { - consumed = true - break + part.drrDeficit += quantum + if part.drrDeficit <= 0 { + continue } - p := MessagePart{ - TransferID: part.transfer.id, - FecType: part.feq, - Part: part.index, - TotalSize: uint64(len(part.transfer.data)), - Seqno: seqno, - Data: part.encoder.GenSymbol(seqno), + allow := part.drrDeficit + if allow > int64(math.MaxInt32) { + allow = int64(math.MaxInt32) } - seqno++ - var msgPart tl.Serializable = p - if isV2 { - msgPart = MessagePartV2(p) + ms = time.Now().UnixMilli() + + requested := int(allow) + consumed := r.rateLimit.ConsumePackets(requested, int(part.fecSymbolSize)) + if consumed == 0 { + drained = true + break } - if err := r.adnl.SendCustomMessage(closerCtx, msgPart); err != nil { - Logger("failed to send recovery message part", p.Seqno, err.Error()) - break loop + prevSeqno := seqno + for j := 0; j < consumed; j++ { + p := MessagePart{ + TransferID: part.transfer.id, + FecType: 
part.fec, + Part: part.index, + TotalSize: part.transfer.totalSize, + Seqno: seqno, + Data: part.encoder.GenSymbol(seqno), + } + seqno++ + + var msgPart tl.Serializable = p + if isV2 { + msgPart = MessagePartV2(p) + } + + part.sendClock.OnSend(p.Seqno, ms) + if err := r.adnl.SendCustomMessage(closerCtx, msgPart); err != nil { + Logger("failed to send recovery message part", p.Seqno, err.Error()) + drained = true + break sendLoop + } } - } - if atomic.LoadUint32(&part.seqno) < seqno { - // we sent something, so considering to be updated - part.lastRecoverAt = ms - if consumed { - part.nextRecoverDelay = 10 - } else { - part.nextRecoverDelay = 50 + if seqno > prevSeqno { + sent := int64(seqno - prevSeqno) + part.drrDeficit -= sent + if part.drrDeficit < 0 { + part.drrDeficit = 0 + } + + base := r.rateCtrl.CurrentMinRTT() + if base <= 0 { + base = r.rateCtrl.opts.DefaultRTTMs + if base <= 0 { + base = 25 + } + } + + minGap := max64(8, base/4) + maxGap := max64(20, base/2) + + if consumed > 0 { + part.nextRecoverDelay = minGap + } else { + part.nextRecoverDelay = maxGap + } + + part.lastRecoverAt = ms + lastServedIdx = idx + } + atomic.StoreUint32(&part.seqno, seqno) + + if consumed < requested { + drained = true + break } } - atomic.StoreUint32(&part.seqno, seqno) - if consumed { - break + if lastServedIdx >= 0 { + rrHead = uint32((lastServedIdx + 1) % n) + } + if drained && lastServedIdx < 0 { + rrHead = uint32(start) } + } else { + rrHead = 0 } if len(timedOut) > 0 || len(timedOutReq) > 0 || len(timedOutExp) > 0 { @@ -733,9 +920,30 @@ func (r *RLDP) recoverySender() { r.mx.Unlock() } + if len(transfersToProcess) == 0 && r.rateLimit.GetTokensLeft() > int64(DefaultSymbolSize)*20 { + r.rateCtrl.SetAppLimited(true) + } else { + r.rateCtrl.SetAppLimited(false) + } + + for i := range transfersToProcess { + transfersToProcess[i] = nil + } transfersToProcess = transfersToProcess[:0] + + for i := range timedOut { + timedOut[i] = nil + } timedOut = timedOut[:0] + + for i := 
range timedOutReq { + timedOutReq[i] = "" + } timedOutReq = timedOutReq[:0] + + for i := range timedOutExp { + timedOutExp[i] = "" + } timedOutExp = timedOutExp[:0] } } @@ -746,6 +954,7 @@ func (r *RLDP) startTransfer(ctx context.Context, transferId, data []byte, recov id: transferId, timeoutAt: recoverTimeoutAt * 1000, // ms data: data, + totalSize: uint64(len(data)), rldp: r, } @@ -783,40 +992,80 @@ func (t *activeTransfer) prepareNextPart() (bool, error) { return false, nil // fmt.Errorf("transfer timed out") } - partIndex := uint32(0) - if cp := t.getCurrentPart(); cp != nil { - partIndex = cp.index + 1 - } - - if len(t.data) <= int(partIndex*PartSize) { - // all parts sent + if len(t.data) == 0 { return false, nil } - payload := t.data[partIndex*PartSize:] + partIndex := t.nextPartIndex + + payload := t.data if len(payload) > int(PartSize) { payload = payload[:PartSize] } - enc, err := raptorq.NewRaptorQ(DefaultSymbolSize).CreateEncoder(payload) - if err != nil { - return false, fmt.Errorf("failed to create raptorq object encoder: %w", err) + if len(payload) == 0 { + return false, nil } + remaining := t.data[len(payload):] - part := activeTransferPart{ - encoder: enc, - seqno: 0, - index: partIndex, - feq: FECRaptorQ{ + cnt := uint32(len(payload))/DefaultSymbolSize + 1 + + var err error + var enc fecEncoder + var fec FEC + + //goland:noinspection GoBoolExpressions + if MultiFECMode && len(payload) < int(RoundRobinFECLimit) { + enc, err = roundrobin.NewEncoder(payload, DefaultSymbolSize) + if err != nil { + return false, fmt.Errorf("failed to create rr object encoder: %w", err) + } + + fec = FECRoundRobin{ + DataSize: uint32(len(payload)), + SymbolSize: DefaultSymbolSize, + SymbolsCount: cnt, + } + } else { + enc, err = raptorq.NewRaptorQ(DefaultSymbolSize).CreateEncoder(payload) + if err != nil { + return false, fmt.Errorf("failed to create raptorq object encoder: %w", err) + } + + fec = FECRaptorQ{ DataSize: uint32(len(payload)), SymbolSize: 
DefaultSymbolSize, - SymbolsCount: enc.BaseSymbolsNum(), - }, - nextRecoverDelay: 30, - fastSeqnoTill: enc.BaseSymbolsNum() + enc.BaseSymbolsNum()/33, // +3% + SymbolsCount: cnt, + } + } + + part := activeTransferPart{ + encoder: enc, + seqno: 0, + index: partIndex, + fec: fec, + fecSymbolsCount: fec.GetSymbolsCount(), + fecSymbolSize: fec.GetSymbolSize(), + nextRecoverDelay: 15, + fastSeqnoTill: cnt + cnt/50 + 1, // +2% transfer: t, } + pt := uint32(1) << uint32(math.Ceil(math.Log2(float64(part.fecSymbolsCount)))) + if pt > 16<<10 { + pt = 16 << 10 + } else if pt < 64 { + pt = 64 + } + part.sendClock = NewSendClock(int(pt)) + + if len(remaining) == 0 { + t.data = nil + } else { + t.data = remaining + } + + t.nextPartIndex++ t.currentPart.Store(&part) return true, nil } @@ -829,34 +1078,47 @@ func (r *RLDP) sendFastSymbols(ctx context.Context, transfer *activeTransfer) er p := MessagePart{ TransferID: transfer.id, - FecType: part.feq, + FecType: part.fec, Part: part.index, - TotalSize: uint64(len(transfer.data)), + TotalSize: transfer.totalSize, } sc := part.fastSeqnoTill isV2 := r.useV2.Load() == 1 - for i := uint32(0); i < sc; i++ { - if !r.rateLimit.TryConsume() { - // we cannot send right now, so we enqueue it - sc = i + maxBatch := int(^uint(0) >> 1) + seqno := uint32(0) + for seqno < sc { + remaining := uint64(sc - seqno) + if remaining > uint64(maxBatch) { + remaining = uint64(maxBatch) + } + + batch := r.rateLimit.ConsumePackets(int(remaining), int(part.fecSymbolSize)) + if batch == 0 { break } + now := time.Now().UnixMilli() - p.Seqno = i - p.Data = part.encoder.GenSymbol(i) + for i := 0; i < batch; i++ { + currentSeqno := seqno + p.Seqno = currentSeqno + p.Data = part.encoder.GenSymbol(currentSeqno) - var msgPart tl.Serializable = p - if isV2 { - msgPart = MessagePartV2(p) - } + var msgPart tl.Serializable = p + if isV2 { + msgPart = MessagePartV2(p) + } - if err := r.adnl.SendCustomMessage(ctx, msgPart); err != nil { - return fmt.Errorf("failed to send 
message part %d: %w", i, err) + part.sendClock.OnSend(currentSeqno, now) + if err := r.adnl.SendCustomMessage(ctx, msgPart); err != nil { + return fmt.Errorf("failed to send message part %d: %w", currentSeqno, err) + } + + seqno++ } } - atomic.StoreUint32(&part.seqno, sc) + atomic.StoreUint32(&part.seqno, seqno) part.recoveryReady.Store(1) select { @@ -895,15 +1157,16 @@ type AsyncQueryResult struct { } func (r *RLDP) DoQueryAsync(ctx context.Context, maxAnswerSize uint64, id []byte, query tl.Serializable, result chan<- AsyncQueryResult) error { - timeout, ok := ctx.Deadline() - if !ok { - timeout = time.Now().Add(15 * time.Second) - } - if len(id) != 32 { return errors.New("invalid id") } + now := time.Now() + timeout, ok := ctx.Deadline() + if !ok { + timeout = now.Add(15 * time.Second) + } + q := &Query{ ID: id, MaxAnswerSize: maxAnswerSize, @@ -911,6 +1174,11 @@ func (r *RLDP) DoQueryAsync(ctx context.Context, maxAnswerSize uint64, id []byte Data: query, } + if uxMin := now.Unix() + 2; int64(q.Timeout) < uxMin { + // because timeout in seconds, we should add some to avoid an early drop + q.Timeout = uint32(uxMin) + } + data, err := tl.Serialize(q, true) if err != nil { return fmt.Errorf("failed to serialize query: %w", err) @@ -968,6 +1236,11 @@ func (r *RLDP) SendAnswer(ctx context.Context, maxAnswerSize uint64, timeoutAt u tm = int64(timeoutAt) } + if minT := time.Now().Unix() + 1; tm < minT { + // give at least 1 sec in case of a clock problem + tm = minT + } + if err = r.startTransfer(ctx, reverseTransferId(toTransferId), data, tm); err != nil { return fmt.Errorf("failed to send partitioned answer: %w", err) } diff --git a/adnl/rldp/client_test.go b/adnl/rldp/client_test.go index b7fbb26e..f0c7bcc4 100644 --- a/adnl/rldp/client_test.go +++ b/adnl/rldp/client_test.go @@ -6,15 +6,17 @@ import ( "crypto/ed25519" "crypto/rand" "crypto/sha256" - "encoding/hex" "errors" + "fmt" "github.com/xssnick/raptorq" "github.com/xssnick/tonutils-go/adnl" 
"github.com/xssnick/tonutils-go/tl" "log" + "net" "net/http" "net/url" "reflect" + "runtime" "strings" "testing" "time" @@ -24,6 +26,8 @@ func init() { tl.Register(testRequest{}, "http.request id:int256 method:string url:string http_version:string headers:(vector http.header) = http.Response") tl.Register(testResponse{}, "http.response http_version:string status_code:int reason:string headers:(vector http.header) no_payload:Bool = http.Response") tl.Register(testHeader{}, "") + tl.Register(benchRequest{}, "benchRequest") + tl.Register(benchResponse{}, "benchResponse") } type MockADNL struct { @@ -63,12 +67,21 @@ func (m MockADNL) SendCustomMessage(ctx context.Context, req tl.Serializable) er func (m MockADNL) Close() { } +type benchRequest struct { + WantLen uint32 `tl:"int"` +} + +type benchResponse struct { + Data []byte `tl:"bytes"` +} + type testRequest struct { ID []byte `tl:"int256"` Method string `tl:"string"` URL string `tl:"string"` Version string `tl:"string"` Headers []testHeader `tl:"vector struct"` + RespSz uint64 `tl:"long"` } type testResponse struct { @@ -665,7 +678,7 @@ func TestRLDP_ClientServer(t *testing.T) { res := testResponse{ Version: "HTTP/1.1", StatusCode: int32(200), - Reason: "test ok:" + hex.EncodeToString(q.ID) + q.URL, + Reason: q.URL, Headers: []testHeader{{"test", "test"}}, NoPayload: true, } @@ -710,12 +723,13 @@ func TestRLDP_ClientServer(t *testing.T) { t.Fatal("bad client execution, err: ", err) } - if resp.Reason != "test ok:"+hex.EncodeToString(make([]byte, 32))+u { + if resp.Reason != u { t.Fatal("bad response data") } }) Logger = log.Println + t.Run("big multipart 10mb", func(t *testing.T) { old := MaxUnexpectedTransferSize MaxUnexpectedTransferSize = 1 << 30 @@ -736,9 +750,226 @@ func TestRLDP_ClientServer(t *testing.T) { t.Fatal("bad client execution, err: ", err) } - if resp.Reason != "test ok:"+hex.EncodeToString(make([]byte, 32))+u { + if resp.Reason != u { + t.Fatal("bad response data") + } + }) + + t.Run("big 
multipart 4mb rr", func(t *testing.T) { + old := MaxUnexpectedTransferSize + MaxUnexpectedTransferSize = 1 << 30 + MultiFECMode = true + defer func() { + MaxUnexpectedTransferSize = old + MultiFECMode = false + RoundRobinFECLimit = 1 << 30 + }() + + u := strings.Repeat("a", 4*1024*1024) + + var resp testResponse + err := cr.DoQuery(context.Background(), 4096+uint64(len(u)), testRequest{ + ID: make([]byte, 32), + Method: "GET", + URL: u, + Version: "1", + }, &resp) + if err != nil { + t.Fatal("bad client execution, err: ", err) + } + + if resp.Reason != u { t.Fatal("bad response data") } }) } + +func BenchmarkRLDP_ClientServer(b *testing.B) { + old := MaxUnexpectedTransferSize + MaxUnexpectedTransferSize = 1 << 30 + defer func() { + MaxUnexpectedTransferSize = old + }() + + defaultSizes := []uint32{16 << 10, 256 << 10, 1 << 20, 4 << 20, 10 << 20} + + scenarios := []struct { + name string + sizes []uint32 + setup func(*testing.B) (*RLDP, func()) + withParallel bool + }{ + { + name: "loopback_rr", + sizes: defaultSizes, + setup: func(b *testing.B) (*RLDP, func()) { + oldLim := RoundRobinFECLimit + RoundRobinFECLimit = 2 << 30 + MultiFECMode = true + rl, end := setupLoopbackBenchmark(b) + + return rl, func() { + end() + RoundRobinFECLimit = oldLim + } + }, + withParallel: true, + }, + { + name: "loopback_raptorq", + sizes: defaultSizes, + setup: setupLoopbackBenchmark, + withParallel: true, + }, + { + // it requires some time to speedup by bbr, so will show a low rate + name: "netem_loss_raptorq", + sizes: []uint32{4 << 20}, + setup: func(tb *testing.B) (*RLDP, func()) { + return setupNetemBenchmark(tb, 0.02, 50*time.Millisecond, 5*time.Millisecond) + }, + withParallel: true, + }, + } + + for _, sc := range scenarios { + sc := sc + b.Run(sc.name, func(b *testing.B) { + client, cleanup := sc.setup(b) + defer cleanup() + runRLDPBenchSizes(b, client, sc.sizes, sc.withParallel) + }) + } +} + +func runRLDPBenchSizes(b *testing.B, client *RLDP, sizes []uint32, withParallel 
bool) { + for _, sz := range sizes { + b.Run(fmt.Sprintf("resp=%dKB", sz>>10), func(b *testing.B) { + b.SetBytes(int64(sz)) + b.ResetTimer() + + for i := 0; i < b.N; i++ { + var resp benchResponse + if err := client.DoQuery(context.Background(), 1<<30, benchRequest{ + WantLen: sz, + }, &resp); err != nil { + b.Fatalf("client exec err: %v", err) + } + } + }) + + if withParallel { + b.Run(fmt.Sprintf("resp=%dKB/parallel", sz>>10), func(b *testing.B) { + b.SetBytes(int64(sz)) + b.SetParallelism(runtime.NumCPU()) + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + var resp benchResponse + if err := client.DoQuery(context.Background(), 1<<30, benchRequest{ + WantLen: sz, + }, &resp); err != nil { + b.Fatalf("client exec err: %v", err) + } + } + }) + }) + } + } +} + +func configureBenchServer(g *adnl.Gateway) { + g.SetConnectionHandler(func(client adnl.Peer) error { + conn := NewClientV2(client) + conn.SetOnQuery(func(transferId []byte, query *Query) error { + q := query.Data.(benchRequest) + res := benchResponse{Data: make([]byte, q.WantLen)} + return conn.SendAnswer(context.Background(), query.MaxAnswerSize, query.Timeout, query.ID, transferId, res) + }) + return nil + }) +} + +func setupLoopbackBenchmark(b *testing.B) (*RLDP, func()) { + srvPub, srvKey, err := ed25519.GenerateKey(nil) + if err != nil { + b.Fatal(err) + } + + _, cliKey, err := ed25519.GenerateKey(nil) + if err != nil { + b.Fatal(err) + } + + srv := adnl.NewGateway(srvKey) + if err := srv.StartServer("127.0.0.1:19157"); err != nil { + b.Fatal(err) + } + configureBenchServer(srv) + + cliGateway := adnl.NewGateway(cliKey) + if err := cliGateway.StartClient(); err != nil { + b.Fatal(err) + } + + cli, err := cliGateway.RegisterClient("127.0.0.1:19157", srvPub) + if err != nil { + b.Fatal(err) + } + + client := NewClientV2(cli) + + cleanup := func() { + client.Close() + _ = cliGateway.Close() + _ = srv.Close() + } + + return client, cleanup +} + +func setupNetemBenchmark(b 
*testing.B, loss float64, baseDelay, jitter time.Duration) (*RLDP, func()) { + srvPub, srvKey, err := ed25519.GenerateKey(nil) + if err != nil { + b.Fatal(err) + } + + _, cliKey, err := ed25519.GenerateKey(nil) + if err != nil { + b.Fatal(err) + } + + srvConn, cliConn := newMemPacketConnPair(loss, baseDelay, jitter, 512<<10) + + srv := adnl.NewGatewayWithNetManager(srvKey, adnl.NewSingleNetReader(func(string) (net.PacketConn, error) { + return srvConn, nil + })) + if err := srv.StartServer("127.0.0.1:19158"); err != nil { + b.Fatal(err) + } + configureBenchServer(srv) + + cliGateway := adnl.NewGatewayWithNetManager(cliKey, adnl.NewSingleNetReader(func(string) (net.PacketConn, error) { + return cliConn, nil + })) + if err := cliGateway.StartClient(); err != nil { + b.Fatal(err) + } + + cli, err := cliGateway.RegisterClient("127.0.0.1:19158", srvPub) + if err != nil { + b.Fatal(err) + } + + client := NewClientV2(cli) + + cleanup := func() { + client.Close() + _ = cliGateway.Close() + _ = srv.Close() + } + + return client, cleanup +} diff --git a/adnl/rldp/fec.go b/adnl/rldp/fec.go index e6ea8313..4f7ddd3e 100644 --- a/adnl/rldp/fec.go +++ b/adnl/rldp/fec.go @@ -13,12 +13,30 @@ func init() { tl.Register(FECOnline{}, "fec.online data_size:int symbol_size:int symbols_count:int = fec.Type") } +type FEC interface { + GetDataSize() uint32 + GetSymbolSize() uint32 + GetSymbolsCount() uint32 +} + type FECRaptorQ struct { DataSize uint32 // `tl:"int"` SymbolSize uint32 // `tl:"int"` SymbolsCount uint32 // `tl:"int"` } +func (f FECRaptorQ) GetDataSize() uint32 { + return f.DataSize +} + +func (f FECRaptorQ) GetSymbolSize() uint32 { + return f.SymbolSize +} + +func (f FECRaptorQ) GetSymbolsCount() uint32 { + return f.SymbolsCount +} + func (f *FECRaptorQ) Parse(data []byte) ([]byte, error) { if len(data) < 12 { return nil, fmt.Errorf("fec raptor data too short") @@ -44,6 +62,37 @@ type FECRoundRobin struct { SymbolsCount uint32 `tl:"int"` } +func (f FECRoundRobin) GetDataSize() 
uint32 { + return f.DataSize +} + +func (f FECRoundRobin) GetSymbolSize() uint32 { + return f.SymbolSize +} + +func (f FECRoundRobin) GetSymbolsCount() uint32 { + return f.SymbolsCount +} + +func (f *FECRoundRobin) Parse(data []byte) ([]byte, error) { + if len(data) < 12 { + return nil, fmt.Errorf("fec rr data too short") + } + f.DataSize = binary.LittleEndian.Uint32(data[:4]) + f.SymbolSize = binary.LittleEndian.Uint32(data[4:8]) + f.SymbolsCount = binary.LittleEndian.Uint32(data[8:12]) + return data[12:], nil +} + +func (f *FECRoundRobin) Serialize(buf *bytes.Buffer) error { + tmp := make([]byte, 12) + binary.LittleEndian.PutUint32(tmp[0:4], f.DataSize) + binary.LittleEndian.PutUint32(tmp[4:8], f.SymbolSize) + binary.LittleEndian.PutUint32(tmp[8:12], f.SymbolsCount) + buf.Write(tmp) + return nil +} + type FECOnline struct { DataSize uint32 `tl:"int"` SymbolSize uint32 `tl:"int"` diff --git a/adnl/rldp/netem_test.go b/adnl/rldp/netem_test.go new file mode 100644 index 00000000..9d7735ff --- /dev/null +++ b/adnl/rldp/netem_test.go @@ -0,0 +1,333 @@ +package rldp + +import ( + "container/heap" + "math/rand" + "net" + "sync" + "sync/atomic" + "time" +) + +type timeoutError struct{} + +func (timeoutError) Error() string { return "i/o timeout" } +func (timeoutError) Timeout() bool { return true } +func (timeoutError) Temporary() bool { return true } + +var errTimeout timeoutError + +type memPacket struct { + data []byte + addr net.Addr +} + +type scheduledPacket struct { + when time.Time + pkt memPacket +} + +type packetQueue []scheduledPacket + +func (pq packetQueue) Len() int { return len(pq) } + +func (pq packetQueue) Less(i, j int) bool { return pq[i].when.Before(pq[j].when) } + +func (pq packetQueue) Swap(i, j int) { pq[i], pq[j] = pq[j], pq[i] } + +func (pq *packetQueue) Push(x interface{}) { + *pq = append(*pq, x.(scheduledPacket)) +} + +func (pq *packetQueue) Pop() interface{} { + old := *pq + n := len(old) + item := old[n-1] + *pq = old[:n-1] + return item +} 
+ +type memPacketConn struct { + name string + inbox chan memPacket + peer *memPacketConn + closeOnce sync.Once + closed atomic.Bool + closeCh chan struct{} + + baseDelay time.Duration + jitter time.Duration + loss float64 + + rngMu sync.Mutex + rng *rand.Rand + + readDeadline atomic.Int64 + writeDeadline atomic.Int64 + + localAddr *net.UDPAddr + + dispatcherOnce sync.Once + wakeCh chan struct{} + queueMu sync.Mutex + queue packetQueue +} + +func newMemPacketConnPair(loss float64, baseDelay, jitter time.Duration, buf int) (*memPacketConn, *memPacketConn) { + if buf <= 0 { + buf = 1024 + } + now := time.Now().UnixNano() + basePort := 20000 + int(now%10000) + a := &memPacketConn{ + name: "server", + inbox: make(chan memPacket, buf), + closeCh: make(chan struct{}), + baseDelay: baseDelay, + jitter: jitter, + loss: loss, + rng: rand.New(rand.NewSource(now)), + localAddr: &net.UDPAddr{IP: net.IPv4(10, 0, 0, 1), Port: basePort}, + wakeCh: make(chan struct{}, 1), + } + b := &memPacketConn{ + name: "client", + inbox: make(chan memPacket, buf), + closeCh: make(chan struct{}), + baseDelay: baseDelay, + jitter: jitter, + loss: loss, + rng: rand.New(rand.NewSource(now + 1)), + localAddr: &net.UDPAddr{IP: net.IPv4(10, 0, 0, 2), Port: basePort + 1}, + wakeCh: make(chan struct{}, 1), + } + a.peer = b + b.peer = a + a.startDispatcher() + b.startDispatcher() + return a, b +} + +func (c *memPacketConn) randFloat() float64 { + c.rngMu.Lock() + f := c.rng.Float64() + c.rngMu.Unlock() + return f +} + +func (c *memPacketConn) LocalAddr() net.Addr { + return c.localAddr +} + +func (c *memPacketConn) ReadFrom(b []byte) (int, net.Addr, error) { + if c.closed.Load() { + return 0, nil, net.ErrClosed + } + var ( + timer *time.Timer + timerCh <-chan time.Time + ) + if deadline := c.readDeadline.Load(); deadline > 0 { + d := time.Until(time.Unix(0, deadline)) + if d <= 0 { + return 0, nil, errTimeout + } + timer = time.NewTimer(d) + timerCh = timer.C + defer timer.Stop() + } + select { + case 
pkt, ok := <-c.inbox: + if !ok { + return 0, nil, net.ErrClosed + } + n := copy(b, pkt.data) + return n, pkt.addr, nil + case <-c.closeCh: + return 0, nil, net.ErrClosed + case <-timerCh: + return 0, nil, errTimeout + } +} + +func (c *memPacketConn) WriteTo(b []byte, _ net.Addr) (int, error) { + if c.closed.Load() { + return 0, net.ErrClosed + } + peer := c.peer + if peer == nil || peer.closed.Load() { + return 0, net.ErrClosed + } + if deadline := c.writeDeadline.Load(); deadline > 0 && time.Now().After(time.Unix(0, deadline)) { + return 0, errTimeout + } + if len(b) == 0 { + return 0, nil + } + if c.loss > 0 && c.randFloat() < c.loss { + return len(b), nil + } + payload := make([]byte, len(b)) + copy(payload, b) + delay := c.baseDelay + if c.jitter > 0 { + j := time.Duration((c.randFloat()*2 - 1) * float64(c.jitter)) + delay += j + if delay < 0 { + delay = 0 + } + } + peer.startDispatcher() + pkt := memPacket{data: payload, addr: c.localAddr} + peer.enqueuePacket(pkt, time.Now().Add(delay)) + return len(b), nil +} + +func (c *memPacketConn) Close() error { + c.closeOnce.Do(func() { + c.closed.Store(true) + close(c.closeCh) + close(c.inbox) + c.wake() + }) + return nil +} + +func (c *memPacketConn) SetDeadline(t time.Time) error { + if err := c.SetReadDeadline(t); err != nil { + return err + } + return c.SetWriteDeadline(t) +} + +func (c *memPacketConn) SetReadDeadline(t time.Time) error { + if t.IsZero() { + c.readDeadline.Store(0) + } else { + c.readDeadline.Store(t.UnixNano()) + } + return nil +} + +func (c *memPacketConn) SetWriteDeadline(t time.Time) error { + if t.IsZero() { + c.writeDeadline.Store(0) + } else { + c.writeDeadline.Store(t.UnixNano()) + } + return nil +} + +func (c *memPacketConn) SetReadBuffer(int) error { return nil } +func (c *memPacketConn) SetWriteBuffer(int) error { return nil } + +func (c *memPacketConn) startDispatcher() { + c.dispatcherOnce.Do(func() { + if c.wakeCh == nil { + c.wakeCh = make(chan struct{}, 1) + } + 
heap.Init(&c.queue) + go c.dispatchLoop() + }) +} + +func (c *memPacketConn) enqueuePacket(pkt memPacket, when time.Time) { + if c.closed.Load() { + return + } + c.queueMu.Lock() + earliest := c.queue.Len() == 0 + if !earliest { + earliest = when.Before(c.queue[0].when) + } + heap.Push(&c.queue, scheduledPacket{when: when, pkt: pkt}) + c.queueMu.Unlock() + if earliest { + c.wake() + } +} + +func (c *memPacketConn) dispatchLoop() { + timer := time.NewTimer(time.Hour) + if !timer.Stop() { + <-timer.C + } + for { + c.queueMu.Lock() + if c.queue.Len() == 0 { + closed := c.closed.Load() + c.queueMu.Unlock() + if closed { + return + } + select { + case <-c.closeCh: + return + case <-c.wakeCh: + continue + } + } + next := c.queue[0] + wait := time.Until(next.when) + if wait <= 0 { + heap.Pop(&c.queue) + c.queueMu.Unlock() + c.deliver(next.pkt) + continue + } + c.queueMu.Unlock() + + resetTimer(timer, wait) + select { + case <-timer.C: + continue + case <-c.closeCh: + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + return + case <-c.wakeCh: + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + continue + } + } +} + +func (c *memPacketConn) deliver(pkt memPacket) { + if c.closed.Load() { + return + } + select { + case c.inbox <- pkt: + case <-c.closeCh: + default: + } +} + +func (c *memPacketConn) wake() { + if c.wakeCh == nil { + return + } + select { + case c.wakeCh <- struct{}{}: + default: + } +} + +func resetTimer(t *time.Timer, d time.Duration) { + if !t.Stop() { + select { + case <-t.C: + default: + } + } + t.Reset(d) +} diff --git a/adnl/rldp/queue.go b/adnl/rldp/queue.go new file mode 100644 index 00000000..fc7bba83 --- /dev/null +++ b/adnl/rldp/queue.go @@ -0,0 +1,47 @@ +package rldp + +type Queue struct { + ch chan *MessagePart +} + +func NewQueue(sz int) *Queue { + return &Queue{ + ch: make(chan *MessagePart, sz), + } +} + +func (q *Queue) Enqueue(m *MessagePart) { + for { + select { + case q.ch <- m: + // written + return + 
default: + } + + select { + case <-q.ch: + // not written, drop oldest + default: + } + } +} + +func (q *Queue) Dequeue() (*MessagePart, bool) { + select { + case m := <-q.ch: + return m, true + default: + return nil, false + } +} + +func (q *Queue) Drain() { + for { + select { + case <-q.ch: + default: + return + } + } +} diff --git a/adnl/rldp/queue_test.go b/adnl/rldp/queue_test.go new file mode 100644 index 00000000..f9c89ccf --- /dev/null +++ b/adnl/rldp/queue_test.go @@ -0,0 +1,203 @@ +package rldp + +import ( + "runtime" + "sync" + "sync/atomic" + "testing" + "time" +) + +func mp(i uint32) *MessagePart { return &MessagePart{Seqno: i} } + +func TestQueue_Basic(t *testing.T) { + q := NewQueue(4) + + if _, ok := q.Dequeue(); ok { + t.Fatalf("queue must be empty initially") + } + + // put 3 + for i := uint32(0); i < 3; i++ { + q.Enqueue(mp(i)) + } + + // get 3 in order + for i := uint32(0); i < 3; i++ { + m, ok := q.Dequeue() + if !ok { + t.Fatalf("expected ok for i=%d", i) + } + + if m.Seqno != i { + t.Fatalf("want=%d got=%d", i, m.Seqno) + } + } + + // empty again + if _, ok := q.Dequeue(); ok { + t.Fatalf("queue must be empty") + } +} + +func TestQueue_OverwriteOldest(t *testing.T) { + q := NewQueue(4) + + for i := uint32(0); i < 6; i++ { + q.Enqueue(mp(i)) + } + + want := []uint32{2, 3, 4, 5} + for i, w := range want { + m, ok := q.Dequeue() + if !ok { + t.Fatalf("expected ok at i=%d", i) + } + if m == nil || m.Seqno != w { + t.Fatalf("want=%d got=%v at i=%d", w, m, i) + } + } + // пусто + if _, ok := q.Dequeue(); ok { + t.Fatalf("queue must be empty") + } +} + +func TestQueue_OverwriteInterleaved(t *testing.T) { + q := NewQueue(2) + + q.Enqueue(mp(0)) + q.Enqueue(mp(1)) + q.Enqueue(mp(2)) + q.Enqueue(mp(3)) + + // ожидаем 2,3 + m, ok := q.Dequeue() + if !ok || m.Seqno != 2 { + t.Fatalf("want=2 got=%v", m) + } + m, ok = q.Dequeue() + if !ok || m.Seqno != 3 { + t.Fatalf("want=3 got=%v", m) + } +} + +func TestQueue_DequeueEmpty(t *testing.T) { + q := 
NewQueue(1) + if m, ok := q.Dequeue(); ok || m != nil { + t.Fatalf("should be empty") + } +} + +func TestQueue_Concurrent_Stress_NoNilOnOk(t *testing.T) { + const ( + capacity = 256 + producers = 8 + perProducer = 50_000 + ) + + q := NewQueue(capacity) + + var wg sync.WaitGroup + wg.Add(producers) + + // consumer + var okCnt, nilOnOk int64 + done := make(chan struct{}) + go func() { + defer close(done) + stopAt := time.Now().Add(3 * time.Second) + for time.Now().Before(stopAt) { + if m, ok := q.Dequeue(); ok { + atomic.AddInt64(&okCnt, 1) + if m == nil { + atomic.AddInt64(&nilOnOk, 1) + } + } else { + runtime.Gosched() + } + } + }() + + // producers + for p := 0; p < producers; p++ { + pid := uint32(p) + go func(pid uint32) { + defer wg.Done() + for i := uint32(0); i < perProducer; i++ { + q.Enqueue(&MessagePart{Part: pid, Seqno: i}) + if (i & 1023) == 0 { + runtime.Gosched() + } + } + }(pid) + } + + wg.Wait() + <-done + + if n := atomic.LoadInt64(&nilOnOk); n != 0 { + t.Fatalf("got %d cases of (nil, ok=true) — must never happen", n) + } + if atomic.LoadInt64(&okCnt) == 0 { + t.Fatalf("consumer should have received some items") + } +} + +func TestQueue_NoDeadlockOnFull(t *testing.T) { + q := NewQueue(4) + + for i := uint32(0); i < 4; i++ { + q.Enqueue(mp(i)) + } + + stop := make(chan struct{}) + go func() { + defer close(stop) + for i := uint32(4); i < 20_000; i++ { + q.Enqueue(mp(i)) + if (i & 4095) == 0 { + runtime.Gosched() + } + } + }() + + timeout := time.After(500 * time.Millisecond) + read := 0 +loop: + for { + select { + case <-timeout: + break loop + default: + if _, ok := q.Dequeue(); ok { + read++ + } else { + runtime.Gosched() + } + } + } + + <-stop + if read == 0 { + t.Fatalf("expected to read something without deadlock") + } +} + +func TestQueue_OrderAfterOverwriteWindow(t *testing.T) { + q := NewQueue(8) + + for i := uint32(0); i < 20; i++ { + q.Enqueue(mp(i)) + } + + for want := uint32(12); want < 20; want++ { + m, ok := q.Dequeue() + if !ok || m == 
nil || m.Seqno != want { + t.Fatalf("want=%d got=%v (ok=%v)", want, m, ok) + } + } + if _, ok := q.Dequeue(); ok { + t.Fatalf("queue must be empty") + } +} diff --git a/adnl/rldp/rate.go b/adnl/rldp/rate.go deleted file mode 100644 index 8986dd32..00000000 --- a/adnl/rldp/rate.go +++ /dev/null @@ -1,230 +0,0 @@ -package rldp - -import ( - "math" - "sync/atomic" - "time" -) - -type AdaptiveRateOptions struct { - // samples to recalc - MinSamples int64 - - // delay = clamp(DelayBaseMs + rate/DelayPerRateDiv, DelayMinMs, DelayMaxMs). - DelayBaseMs int64 - DelayPerRateDiv int64 - DelayMinMs int64 - DelayMaxMs int64 - - // For EWMA - TargetLoss float64 - HighLoss float64 - - IncreaseFactor float64 // +6.25% => 0.0625 - DecreaseFactor float64 // -33% => 0.33 - MildDecreaseFactor float64 // -5% => 0.05 - - // if diff less, dont touch rate - Deadband float64 // 0.02 = 2% - - // smoothing - EWMAAlpha float64 // 0.1 - - MinRate int64 - MaxRate int64 // 0 = no limit - - IncreaseOnlyWhenTokensBelow float64 - - EnableSlowStart bool // fast rate scale at startup when enabled - SlowStartMultiplier float64 - SlowStartExitLoss float64 -} - -type AdaptiveRateController struct { - limiter *TokenBucket - opts AdaptiveRateOptions - - total atomic.Int64 - recv atomic.Int64 - samples atomic.Int64 - lastProc atomic.Int64 // unix ms - - lossEWMA atomic.Uint64 // float64 bits - - inSlowStart atomic.Bool -} - -func NewAdaptiveRateController(l *TokenBucket, o AdaptiveRateOptions) *AdaptiveRateController { - applyDefaults(&o) - - rc := &AdaptiveRateController{ - limiter: l, - opts: o, - } - if o.EnableSlowStart { - rc.inSlowStart.Store(true) - } - return rc -} - -func (rc *AdaptiveRateController) ObserveDelta(total, recv uint32) { - if total == 0 { - return - } - - rc.total.Add(int64(total)) - rc.recv.Add(int64(recv)) - samples := rc.samples.Add(1) - - nowMs := time.Now().UnixMilli() - rate := rc.limiter.GetRate() - - delay := rc.opts.DelayBaseMs - if rc.opts.DelayPerRateDiv > 0 { - delay += 
rate / rc.opts.DelayPerRateDiv - } - if delay < rc.opts.DelayMinMs { - delay = rc.opts.DelayMinMs - } else if delay > rc.opts.DelayMaxMs { - delay = rc.opts.DelayMaxMs - } - - last := rc.lastProc.Load() - if samples < rc.opts.MinSamples || last+delay > nowMs { - return - } - if !rc.lastProc.CompareAndSwap(last, nowMs) { - return - } - - totalN := rc.total.Swap(0) - recvN := rc.recv.Swap(0) - rc.samples.Store(0) - - if totalN <= 0 { - return - } - - loss := float64(totalN-recvN) / float64(totalN) - - prevEWMA := math.Float64frombits(rc.lossEWMA.Load()) - if prevEWMA == 0 { - prevEWMA = loss - } - ewma := prevEWMA*(1-rc.opts.EWMAAlpha) + loss*rc.opts.EWMAAlpha - rc.lossEWMA.Store(math.Float64bits(ewma)) - - newRate := rate - - if rc.inSlowStart.Load() { - // aggressive while loss is low - if ewma < rc.opts.TargetLoss { - m := rc.opts.SlowStartMultiplier - if m < 1.1 { - m = 2.0 - } - up := int64(float64(rate) * (m - 1)) - if up < 1 { - up = 1 - } - newRate = rate + up - } else { - // too high loss - rc.inSlowStart.Store(false) - } - - if rc.opts.SlowStartExitLoss > 0 && ewma > rc.opts.SlowStartExitLoss { - rc.inSlowStart.Store(false) - } - } - - if !rc.inSlowStart.Load() { - switch { - case ewma < rc.opts.TargetLoss: - if rc.opts.IncreaseOnlyWhenTokensBelow > 0 { - tokens := rc.limiter.GetTokensLeft() - threshold := int64(float64(rate) * rc.opts.IncreaseOnlyWhenTokensBelow) - if tokens < threshold { - newRate = rate + int64(float64(rate)*rc.opts.IncreaseFactor) - } - } else { - newRate = rate + int64(float64(rate)*rc.opts.IncreaseFactor) - } - case ewma > rc.opts.HighLoss: - newRate = rate - int64(float64(rate)*rc.opts.DecreaseFactor) - default: - newRate = rate - int64(float64(rate)*rc.opts.MildDecreaseFactor) - } - } - - if newRate < rc.opts.MinRate { - newRate = rc.opts.MinRate - } - if rc.opts.MaxRate > 0 && newRate > rc.opts.MaxRate { - newRate = rc.opts.MaxRate - } - - if rate > 0 { - diff := math.Abs(float64(newRate-rate)) / float64(rate) - if diff < 
rc.opts.Deadband { - return - } - } - - if newRate != rate { - rc.limiter.SetRate(newRate) - } -} - -func applyDefaults(o *AdaptiveRateOptions) { - if o.MinSamples == 0 { - o.MinSamples = 3 - } - if o.DelayBaseMs == 0 { - o.DelayBaseMs = 10 - } - if o.DelayPerRateDiv == 0 { - o.DelayPerRateDiv = 2000 - } - if o.DelayMinMs == 0 { - o.DelayMinMs = 10 - } - if o.DelayMaxMs == 0 { - o.DelayMaxMs = 500 - } - if o.TargetLoss == 0 { - o.TargetLoss = 0.02 - } - if o.HighLoss == 0 { - o.HighLoss = 0.10 - } - if o.IncreaseFactor == 0 { - o.IncreaseFactor = 0.0625 // +6.25% - } - if o.DecreaseFactor == 0 { - o.DecreaseFactor = 0.33 // -33% - } - if o.MildDecreaseFactor == 0 { - o.MildDecreaseFactor = 0.05 // -5% - } - if o.Deadband == 0 { - o.Deadband = 0.02 - } - if o.EWMAAlpha == 0 { - o.EWMAAlpha = 0.1 - } - if o.MinRate == 0 { - o.MinRate = 6000 - } - if o.IncreaseOnlyWhenTokensBelow == 0 { - o.IncreaseOnlyWhenTokensBelow = 0.0 - } - if o.EnableSlowStart { - if o.SlowStartMultiplier == 0 { - o.SlowStartMultiplier = 2.0 - } - if o.SlowStartExitLoss == 0 { - o.SlowStartExitLoss = o.TargetLoss * 2 - } - } -} diff --git a/adnl/rldp/rldp.go b/adnl/rldp/rldp.go index 9f6ba109..257035e8 100644 --- a/adnl/rldp/rldp.go +++ b/adnl/rldp/rldp.go @@ -63,7 +63,7 @@ type CompleteV2 struct { type MessagePart struct { TransferID []byte // `tl:"int256"` - FecType any // `tl:"struct boxed [fec.roundRobin,fec.raptorQ,fec.online]"` + FecType FEC // `tl:"struct boxed [fec.roundRobin,fec.raptorQ,fec.online]"` Part uint32 // `tl:"int"` TotalSize uint64 // `tl:"long"` Seqno uint32 // `tl:"int"` @@ -78,12 +78,17 @@ func (m *MessagePart) Parse(data []byte) ([]byte, error) { transfer := make([]byte, 32) copy(transfer, data) - var fec FECRaptorQ - data, err := tl.Parse(&fec, data[32:], true) + var fecAny any + data, err := tl.Parse(&fecAny, data[32:], true) if err != nil { return nil, err } + fec, ok := fecAny.(FEC) + if !ok { + return nil, errors.New("invalid fec type") + } + if len(data) < 20 { 
return nil, errors.New("message part is too short") } @@ -110,37 +115,110 @@ func (m *MessagePart) Parse(data []byte) ([]byte, error) { func (m *MessagePart) Serialize(buf *bytes.Buffer) error { switch m.FecType.(type) { case FECRaptorQ: - if len(m.TransferID) == 0 { - buf.Write(make([]byte, 32)) - } else if len(m.TransferID) != 32 { - return errors.New("invalid transfer id") - } else { - buf.Write(m.TransferID) - } - - _, err := tl.Serialize(m.FecType, true, buf) - if err != nil { - return err - } + case FECRoundRobin: default: return errors.New("invalid fec type") } + if len(m.TransferID) == 0 { + buf.Write(make([]byte, 32)) + } else if len(m.TransferID) != 32 { + return errors.New("invalid transfer id") + } else { + buf.Write(m.TransferID) + } + + _, err := tl.Serialize(m.FecType, true, buf) + if err != nil { + return err + } + tmp := make([]byte, 16) binary.LittleEndian.PutUint32(tmp, m.Part) binary.LittleEndian.PutUint64(tmp[4:], m.TotalSize) binary.LittleEndian.PutUint32(tmp[12:], m.Seqno) buf.Write(tmp) - tl.ToBytesToBuffer(buf, m.Data) - return nil + return tl.ToBytesToBuffer(buf, m.Data) } type MessagePartV2 struct { TransferID []byte `tl:"int256"` - FecType any `tl:"struct boxed [fec.roundRobin,fec.raptorQ,fec.online]"` + FecType FEC `tl:"struct boxed [fec.roundRobin,fec.raptorQ,fec.online]"` Part uint32 `tl:"int"` TotalSize uint64 `tl:"long"` Seqno uint32 `tl:"int"` Data []byte `tl:"bytes"` } + +func (m *MessagePartV2) Parse(data []byte) ([]byte, error) { + if len(data) < 56 { + return nil, errors.New("message part is too short") + } + + transfer := make([]byte, 32) + copy(transfer, data) + + var fecAny any + data, err := tl.Parse(&fecAny, data[32:], true) + if err != nil { + return nil, err + } + + fec, ok := fecAny.(FEC) + if !ok { + return nil, errors.New("invalid fec type") + } + + if len(data) < 20 { + return nil, errors.New("message part is too short") + } + + part := binary.LittleEndian.Uint32(data) + size := binary.LittleEndian.Uint64(data[4:]) + 
seq := binary.LittleEndian.Uint32(data[12:]) + + slc, data, err := tl.FromBytes(data[16:]) + if err != nil { + return nil, fmt.Errorf("tl.FromBytes: %v", err) + } + + m.TransferID = transfer + m.FecType = fec + m.Part = part + m.TotalSize = size + m.Seqno = seq + m.Data = slc + + return data, nil +} + +func (m *MessagePartV2) Serialize(buf *bytes.Buffer) error { + switch m.FecType.(type) { + case FECRaptorQ: + case FECRoundRobin: + default: + return errors.New("invalid fec type") + } + + if len(m.TransferID) == 0 { + buf.Write(make([]byte, 32)) + } else if len(m.TransferID) != 32 { + return errors.New("invalid transfer id") + } else { + buf.Write(m.TransferID) + } + + _, err := tl.Serialize(m.FecType, true, buf) + if err != nil { + return err + } + + tmp := make([]byte, 16) + binary.LittleEndian.PutUint32(tmp, m.Part) + binary.LittleEndian.PutUint64(tmp[4:], m.TotalSize) + binary.LittleEndian.PutUint32(tmp[12:], m.Seqno) + buf.Write(tmp) + + return tl.ToBytesToBuffer(buf, m.Data) +} diff --git a/adnl/rldp/roundrobin/coder.go b/adnl/rldp/roundrobin/coder.go new file mode 100644 index 00000000..c9810efe --- /dev/null +++ b/adnl/rldp/roundrobin/coder.go @@ -0,0 +1,99 @@ +package roundrobin + +import "errors" + +type Encoder struct { + data []byte + symbolSize uint32 + symbolsCount uint32 +} + +type Decoder struct { + data []byte + mask []bool + left uint32 + symbolSize uint32 + symbolsCount uint32 +} + +func NewEncoder(data []byte, maxSymbolSize uint32) (*Encoder, error) { + syms := (len(data) + int(maxSymbolSize) - 1) / int(maxSymbolSize) + if syms == 0 { + return nil, errors.New("data must be non-empty") + } + + return &Encoder{ + data: data, + symbolSize: maxSymbolSize, + symbolsCount: uint32(syms), + }, nil +} + +func (e *Encoder) GenSymbol(id uint32) []byte { + if e.symbolsCount == 0 { + return nil + } + pos := id % e.symbolsCount + offset := pos * e.symbolSize + end := offset + e.symbolSize + + out := make([]byte, e.symbolSize) + if int(offset) < len(e.data) { + 
if int(end) > len(e.data) { + end = uint32(len(e.data)) + } + copy(out, e.data[offset:end]) + } + return out +} + +func NewDecoder(symbolSize uint32, dataSize uint32) (*Decoder, error) { + syms := (dataSize + symbolSize - 1) / symbolSize + if syms == 0 { + return nil, errors.New("dataSize must be > 0") + } + + return &Decoder{ + data: make([]byte, dataSize), + mask: make([]bool, syms), + left: syms, + symbolSize: symbolSize, + symbolsCount: syms, + }, nil +} + +func (d *Decoder) AddSymbol(id uint32, sym []byte) (bool, error) { + if uint32(len(sym)) != d.symbolSize { + return false, errors.New("invalid symbol length") + } + + if d.symbolsCount == 0 { + return false, errors.New("decoder not initialized") + } + + pos := id % d.symbolsCount + idx := int(pos) + if d.mask[idx] { + return d.left == 0, nil + } + + offset := idx * int(d.symbolSize) + end := offset + int(d.symbolSize) + if offset < len(d.data) { + if end > len(d.data) { + end = len(d.data) + } + copy(d.data[offset:end], sym[:end-offset]) + } + + d.mask[idx] = true + d.left-- + return d.left == 0, nil +} + +func (d *Decoder) Decode() (bool, []byte, error) { + if d.left != 0 { + return false, nil, errors.New("not ready") + } + return true, d.data, nil +} diff --git a/adnl/rldp/roundrobin/coder_test.go b/adnl/rldp/roundrobin/coder_test.go new file mode 100644 index 00000000..496238b5 --- /dev/null +++ b/adnl/rldp/roundrobin/coder_test.go @@ -0,0 +1,222 @@ +package roundrobin + +import ( + "crypto/rand" + mrand "math/rand" + "testing" + "time" +) + +func genRandomBytes(t *testing.T, n int) []byte { + t.Helper() + if n == 0 { + return nil + } + b := make([]byte, n) + if _, err := rand.Read(b); err != nil { + t.Fatalf("rand.Read: %v", err) + } + return b +} + +func TestNewEncoder_EmptyData(t *testing.T) { + _, err := NewEncoder(nil, 256) + if err == nil { + t.Fatalf("expected error for empty data, got nil") + } +} + +func TestNewDecoder_BadParams(t *testing.T) { + _, err := NewDecoder(256, 0) + if err == nil { + 
t.Fatalf("expected error for dataSize=0, got nil") + } +} + +func TestRoundTrip_ExactMultiple(t *testing.T) { + const symbolSize = 256 + const dataSize = 4096 + data := genRandomBytes(t, dataSize) + + enc, err := NewEncoder(data, symbolSize) + if err != nil { + t.Fatalf("NewEncoder: %v", err) + } + dec, err := NewDecoder(symbolSize, uint32(len(data))) + if err != nil { + t.Fatalf("NewDecoder: %v", err) + } + + for i := uint32(0); i < enc.symbolsCount; i++ { + s := enc.GenSymbol(i) + if uint32(len(s)) != symbolSize { + t.Fatalf("symbol len mismatch: got %d, want %d", len(s), symbolSize) + } + } + + ids := make([]uint32, enc.symbolsCount) + for i := range ids { + ids[i] = uint32(i) + } + r := mrand.New(mrand.NewSource(1)) + r.Shuffle(len(ids), func(i, j int) { ids[i], ids[j] = ids[j], ids[i] }) + + complete := false + for idx, id := range ids { + sym := enc.GenSymbol(id) + done, err := dec.AddSymbol(id, sym) + if err != nil { + t.Fatalf("AddSymbol #%d (id=%d): %v", idx, id, err) + } + if done && idx != len(ids)-1 { + t.Fatalf("completed too early at #%d", idx) + } + complete = done + } + if !complete { + t.Fatalf("not completed after feeding all symbols") + } + + ready, out, err := dec.Decode() + if err != nil { + t.Fatalf("Decode error: %v", err) + } + if !ready { + t.Fatalf("Decode not ready") + } + if len(out) != len(data) { + t.Fatalf("decoded length mismatch: got %d, want %d", len(out), len(data)) + } + if string(out) != string(data) { + t.Fatalf("decoded data mismatch") + } +} + +func TestRoundTrip_WithTailRemainder(t *testing.T) { + const symbolSize = 256 + const dataSize = 1000 + data := genRandomBytes(t, dataSize) + + enc, err := NewEncoder(data, symbolSize) + if err != nil { + t.Fatalf("NewEncoder: %v", err) + } + + lastID := enc.symbolsCount - 1 + last := enc.GenSymbol(lastID) + if len(last) != symbolSize { + t.Fatalf("last symbol len: got %d, want %d", len(last), symbolSize) + } + + dec, err := NewDecoder(symbolSize, uint32(len(data))) + if err != nil { + 
t.Fatalf("NewDecoder: %v", err) + } + + ids := make([]uint32, enc.symbolsCount) + for i := range ids { + ids[i] = uint32(i) + } + r := mrand.New(mrand.NewSource(time.Now().UnixNano())) + r.Shuffle(len(ids), func(i, j int) { ids[i], ids[j] = ids[j], ids[i] }) + + for _, id := range ids { + sym := enc.GenSymbol(id) + if _, err := dec.AddSymbol(id, sym); err != nil { + t.Fatalf("AddSymbol id=%d: %v", id, err) + } + if _, err := dec.AddSymbol(id, sym); err != nil { + t.Fatalf("AddSymbol duplicate id=%d: %v", id, err) + } + } + + ready, out, err := dec.Decode() + if err != nil { + t.Fatalf("Decode error: %v", err) + } + if !ready { + t.Fatalf("Decode not ready") + } + if len(out) != len(data) { + t.Fatalf("decoded length mismatch: got %d, want %d", len(out), len(data)) + } + if string(out) != string(data) { + t.Fatalf("decoded data mismatch") + } +} + +func TestInvalidSymbolLength(t *testing.T) { + const symbolSize = 128 + const dataSize = 1024 + data := genRandomBytes(t, dataSize) + + enc, err := NewEncoder(data, symbolSize) + if err != nil { + t.Fatalf("NewEncoder: %v", err) + } + dec, err := NewDecoder(symbolSize, uint32(len(data))) + if err != nil { + t.Fatalf("NewDecoder: %v", err) + } + + id := uint32(0) + okSym := enc.GenSymbol(id) + if len(okSym) != symbolSize { + t.Fatalf("GenSymbol produced wrong length: %d", len(okSym)) + } + + badShort := okSym[:symbolSize-1] + if _, err := dec.AddSymbol(id, badShort); err == nil { + t.Fatalf("expected error for short symbol, got nil") + } + + badLong := append(okSym, 0xFF) + if _, err := dec.AddSymbol(id, badLong); err == nil { + t.Fatalf("expected error for long symbol, got nil") + } +} + +func TestDecode_NotReady(t *testing.T) { + const symbolSize = 64 + const dataSize = 1000 + data := genRandomBytes(t, dataSize) + + enc, err := NewEncoder(data, symbolSize) + if err != nil { + t.Fatalf("NewEncoder: %v", err) + } + dec, err := NewDecoder(symbolSize, uint32(len(data))) + if err != nil { + t.Fatalf("NewDecoder: %v", err) + } 
+ + for id := uint32(0); id < enc.symbolsCount-1; id++ { + if _, err := dec.AddSymbol(id, enc.GenSymbol(id)); err != nil { + t.Fatalf("AddSymbol: %v", err) + } + } + + if ready, _, err := dec.Decode(); err == nil || ready { + t.Fatalf("expected not ready error, got ready=%v err=%v", ready, err) + } +} + +func TestGenSymbol_ModuloBehavior(t *testing.T) { + const symbolSize = 128 + const dataSize = 777 + data := genRandomBytes(t, dataSize) + + enc, err := NewEncoder(data, symbolSize) + if err != nil { + t.Fatalf("NewEncoder: %v", err) + } + + if enc.symbolsCount < 2 { + t.Skip("need at least 2 symbols to test modulo behavior") + } + ref := enc.GenSymbol(1) + same := enc.GenSymbol(enc.symbolsCount + 1) + if string(ref) != string(same) { + t.Fatalf("GenSymbol modulo mismatch: got different payloads for ids 1 and symbolsCount+1") + } +} diff --git a/go.mod b/go.mod index d423471a..0a5e16b6 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,6 @@ toolchain go1.24.3 require ( filippo.io/edwards25519 v1.1.0 - github.com/xssnick/raptorq v1.2.0 + github.com/xssnick/raptorq v1.3.0 golang.org/x/crypto v0.42.0 ) diff --git a/go.sum b/go.sum index 8e5e7410..e8b019ff 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,6 @@ filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= -github.com/xssnick/raptorq v1.2.0 h1:ts8yjB3Xns3GS4V7/xQLx9lYLAlSwzlvJjGDLscZCDA= -github.com/xssnick/raptorq v1.2.0/go.mod h1:kgEVVsZv2hP+IeV7C7985KIFsDdvYq2ARW234SBA9Q4= +github.com/xssnick/raptorq v1.3.0 h1:3GoaySKMg/i8rbjhIuqjxpTTO2l3Gs2/Gh7k3GAjvGo= +github.com/xssnick/raptorq v1.3.0/go.mod h1:kgEVVsZv2hP+IeV7C7985KIFsDdvYq2ARW234SBA9Q4= golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI= golang.org/x/crypto v0.42.0/go.mod h1:4+rDnOTJhQCx2q7/j6rAN5XDw8kPjeaXEUR2eL94ix8= diff --git a/tl/bytes.go b/tl/bytes.go index 6f2269dd..acc9a997 100644 --- a/tl/bytes.go +++ b/tl/bytes.go @@ -7,39 
+7,21 @@ import ( "fmt" ) -func ToBytes(buf []byte) []byte { - var data = make([]byte, 0, ((len(buf)+4)/4+1)*4) - - // store buf length - if len(buf) >= 0xFE { - ln := make([]byte, 4) - binary.LittleEndian.PutUint32(ln, uint32(len(buf)<<8)|0xFE) - data = append(data, ln...) - } else { - data = append(data, byte(len(buf))) - } - - data = append(data, buf...) - - // adjust actual length to fit % 4 = 0 - if round := len(data) % 4; round != 0 { - data = append(data, make([]byte, 4-round)...) - } - - return data -} - -func ToBytesToBuffer(buf *bytes.Buffer, data []byte) { +func ToBytesToBuffer(buf *bytes.Buffer, data []byte) error { if len(data) == 0 { // fast path for empty slice buf.Write(make([]byte, 4)) - return + return nil } prevLen := buf.Len() // store buf length if len(data) >= 0xFE { + if len(data) >= 1<<24 { + return fmt.Errorf("too big bytes len, TL bytes array limited by 1<<24") + } + ln := make([]byte, 4) binary.LittleEndian.PutUint32(ln, uint32(len(data)<<8)|0xFE) buf.Write(ln) @@ -55,6 +37,7 @@ func ToBytesToBuffer(buf *bytes.Buffer, data []byte) { buf.WriteByte(0) } } + return nil } func RemapBufferAsSlice(buf *bytes.Buffer, from int) { diff --git a/tl/bytes_test.go b/tl/bytes_test.go index 334e4541..424b3967 100644 --- a/tl/bytes_test.go +++ b/tl/bytes_test.go @@ -7,13 +7,18 @@ import ( func TestTLBytes(t *testing.T) { buf := []byte{0xFF, 0xAA} - if !bytes.Equal(append([]byte{2}, append(buf, 0)...), ToBytes(buf)) { + b := &bytes.Buffer{} + ToBytesToBuffer(b, buf) + + if !bytes.Equal(append([]byte{2}, append(buf, 0)...), b.Bytes()) { t.Fatal("not equal small") return } buf = []byte{0xFF, 0xAA, 0xCC} - if !bytes.Equal(append([]byte{3}, buf...), ToBytes(buf)) { + b.Reset() + ToBytesToBuffer(b, buf) + if !bytes.Equal(append([]byte{3}, buf...), b.Bytes()) { t.Fatal("not equal small 2") return } @@ -23,8 +28,11 @@ func TestTLBytes(t *testing.T) { buf = append(buf, 0xFF) } + b.Reset() + ToBytesToBuffer(b, buf) + // corner case + round to 4 - if 
!bytes.Equal(append([]byte{0xFE, 0xFE, 0x00, 0x00}, append(buf, 0x00, 0x00)...), ToBytes(buf)) { + if !bytes.Equal(append([]byte{0xFE, 0xFE, 0x00, 0x00}, append(buf, 0x00, 0x00)...), b.Bytes()) { t.Fatal("not equal middle") return } @@ -33,9 +41,16 @@ func TestTLBytes(t *testing.T) { for i := 0; i < 1217; i++ { buf = append(buf, byte(i%256)) } + b.Reset() + ToBytesToBuffer(b, buf) - if !bytes.Equal(append([]byte{0xFE, 0xC1, 0x04, 0x00}, append(buf, 0x00, 0x00, 0x00)...), ToBytes(buf)) { + if !bytes.Equal(append([]byte{0xFE, 0xC1, 0x04, 0x00}, append(buf, 0x00, 0x00, 0x00)...), b.Bytes()) { t.Fatal("not equal big") return } + + b.Reset() + if err := ToBytesToBuffer(b, make([]byte, 1<<24)); err == nil { + t.Fatal("should be error") + } } diff --git a/tl/loader_test.go b/tl/loader_test.go index 467c4637..ba791f94 100644 --- a/tl/loader_test.go +++ b/tl/loader_test.go @@ -97,6 +97,7 @@ func init() { Register(TestInner{}, "in 123") // root 777 Register(TestTL{}, "root 222") Register(TestManual{}, "manual val") + Register(AnyBig{}, "anybig") buf := make([]byte, 4) binary.LittleEndian.PutUint32(buf, RegisterWithFabric(Small{}, "small 123", func() reflect.Value { @@ -173,3 +174,21 @@ func BenchmarkParse(b *testing.B) { } _ = tst } + +type AnyBig struct { + Data [][]byte `tl:"vector bytes"` +} + +func BenchmarkSerialize(b *testing.B) { + v := AnyBig{} + for i := 0; i < 100; i++ { + v.Data = append(v.Data, make([]byte, 1<<20)) + } + + for i := 0; i < b.N; i++ { + _, err := Serialize(&v, true) + if err != nil { + panic(err) + } + } +} diff --git a/tl/precompile.go b/tl/precompile.go index 2a844a08..fe6d231c 100644 --- a/tl/precompile.go +++ b/tl/precompile.go @@ -802,9 +802,13 @@ func executeSerialize(buf *bytes.Buffer, startPtr uintptr, si *structInfo) error binary.LittleEndian.PutUint32(tmp, flags) buf.Write(tmp) case _ExecuteTypeString: - ToBytesToBuffer(buf, []byte(*(*string)(ptr))) + if err := ToBytesToBuffer(buf, []byte(*(*string)(ptr))); err != nil { + return 
fmt.Errorf("failed to serialize string field %s: %w", field.String(), err) + } case _ExecuteTypeBytes: - ToBytesToBuffer(buf, *(*[]byte)(ptr)) + if err := ToBytesToBuffer(buf, *(*[]byte)(ptr)); err != nil { + return fmt.Errorf("failed to serialize bytes field %s: %w", field.String(), err) + } case _ExecuteTypeInt256: if bts := *(*[]byte)(ptr); len(bts) == 32 { buf.Write(*(*[]byte)(ptr)) @@ -873,26 +877,31 @@ func executeSerialize(buf *bytes.Buffer, startPtr uintptr, si *structInfo) error c := *(**cell.Cell)(ptr) if c == nil { if field.meta.(bool) { - ToBytesToBuffer(buf, nil) + _ = ToBytesToBuffer(buf, nil) break } return fmt.Errorf("nil cell is not allowed in field %s", field.String()) } - ToBytesToBuffer(buf, (*(**cell.Cell)(ptr)).ToBOCWithFlags(false)) + + if err := ToBytesToBuffer(buf, (*(**cell.Cell)(ptr)).ToBOCWithFlags(false)); err != nil { + return fmt.Errorf("failed to serialize cell field %s: %w", field.String(), err) + } case _ExecuteTypeSliceCell: c := *(*[]*cell.Cell)(ptr) flag := field.meta.(uint32) num := flag & 0x7FFFFFFF if len(c) == 0 && flag&(1<<31) != 0 { - ToBytesToBuffer(buf, nil) + _ = ToBytesToBuffer(buf, nil) break } if num > 0 && uint32(len(c)) != num { return fmt.Errorf("incorrect cells len %d in field %s", len(c), field.String()) } - ToBytesToBuffer(buf, cell.ToBOCWithFlags(c, false)) + if err := ToBytesToBuffer(buf, cell.ToBOCWithFlags(c, false)); err != nil { + return fmt.Errorf("failed to serialize slice cell field %s: %w", field.String(), err) + } case _ExecuteTypeStruct: info := field.structInfo structFlags := field.meta.(uint32) diff --git a/ton/transactions.go b/ton/transactions.go index 5c965950..f2ccd323 100644 --- a/ton/transactions.go +++ b/ton/transactions.go @@ -362,8 +362,6 @@ func (c *APIClient) findLastTransactionByHash(ctx context.Context, addr *address return transaction, nil } } - - continue } else { if transaction.IO.In == nil { continue @@ -378,8 +376,6 @@ func (c *APIClient) findLastTransactionByHash(ctx 
context.Context, addr *address return transaction, nil } } - - return transaction, nil } scanned += 15 diff --git a/ton/wallet/wallet.go b/ton/wallet/wallet.go index cce294b0..7796e194 100644 --- a/ton/wallet/wallet.go +++ b/ton/wallet/wallet.go @@ -180,6 +180,7 @@ func FromPrivateKeyWithOptions(key ed25519.PrivateKey, version VersionConfig, op append([]Option{WithPrivateKey(key)}, options...)...) } +// Deprecated: use FromPubKeyWithOptions(publicKey, version, WithSigner(signer)) instead. func FromSigner(api TonAPI, publicKey ed25519.PublicKey, version VersionConfig, signer Signer) (*Wallet, error) { return newWallet( publicKey,