// iroh_blobs/provider.rs

1//! The low level server side API
2//!
3//! Note that while using this API directly is fine, the standard way
4//! to provide data is to just register a [`crate::BlobsProtocol`] protocol
5//! handler with an [`iroh::Endpoint`](iroh::protocol::Router).
6use std::{fmt::Debug, future::Future, io, time::Duration};
7
8use bao_tree::ChunkRanges;
9use iroh::endpoint::{self, ConnectionError, VarInt};
10use iroh_io::{AsyncStreamReader, AsyncStreamWriter};
11use n0_error::{e, stack_error, Result};
12use n0_future::{time::Instant, StreamExt};
13use serde::{Deserialize, Serialize};
14use tokio::select;
15use tracing::{debug, debug_span, Instrument};
16
17use crate::{
18    api::{
19        blobs::{Bitfield, WriteProgress},
20        ExportBaoError, ExportBaoResult, RequestError, Store,
21    },
22    hashseq::HashSeq,
23    protocol::{
24        GetManyRequest, GetRequest, ObserveItem, ObserveRequest, PushRequest, Request, ERR_INTERNAL,
25    },
26    provider::events::{
27        ClientConnected, ClientResult, ConnectionClosed, HasErrorCode, ProgressError,
28        RequestTracker,
29    },
30    util::{RecvStream, RecvStreamExt, SendStream, SendStreamExt},
31    Hash,
32};
33pub mod events;
34use events::EventSender;
35
/// Default receive stream type used when `StreamPair` / `ProgressReader` are not
/// parameterized with a custom reader.
type DefaultReader = iroh::endpoint::RecvStream;
/// Default send stream type used when `StreamPair` / `ProgressWriter` are not
/// parameterized with a custom writer.
type DefaultWriter = iroh::endpoint::SendStream;
38
/// Statistics about a successful or failed transfer.
///
/// Boxed copies of this are handed to the request tracker when a transfer
/// completes or is aborted.
#[derive(Debug, Serialize, Deserialize)]
pub struct TransferStats {
    /// The number of bytes sent that are part of the payload.
    pub payload_bytes_sent: u64,
    /// The number of bytes sent that are not part of the payload.
    ///
    /// Hash pairs and the initial size header.
    pub other_bytes_sent: u64,
    /// The number of bytes read from the stream.
    ///
    /// In most cases this is just the request, for push requests this is
    /// request, size header and hash pairs.
    pub other_bytes_read: u64,
    /// Total duration from reading the request to transfer completed.
    pub duration: Duration,
}
56
/// A pair of [`SendStream`] and [`RecvStream`] with additional context data.
#[derive(Debug)]
pub struct StreamPair<R: RecvStream = DefaultReader, W: SendStream = DefaultWriter> {
    /// Creation time of the pair; transfer duration is measured from here.
    t0: Instant,
    /// Stable id of the connection this stream belongs to.
    connection_id: u64,
    /// Receive side of the bidirectional stream.
    reader: R,
    /// Send side of the bidirectional stream.
    writer: W,
    /// Non-payload bytes read so far (e.g. the encoded request).
    other_bytes_read: u64,
    /// Sink for provider events (request tracking, progress).
    events: EventSender,
}
67
impl StreamPair {
    /// Accept the next bidirectional stream on `conn` and wrap it together
    /// with the connection's stable id and the given event sender.
    pub async fn accept(
        conn: &endpoint::Connection,
        events: EventSender,
    ) -> Result<Self, ConnectionError> {
        // accept_bi yields (send, recv); we store them as (writer, reader)
        let (writer, reader) = conn.accept_bi().await?;
        Ok(Self::new(conn.stable_id() as u64, reader, writer, events))
    }
}
77
impl<R: RecvStream, W: SendStream> StreamPair<R, W> {
    /// Id of the underlying receive stream.
    pub fn stream_id(&self) -> u64 {
        self.reader.id()
    }

    /// Create a new pair from raw parts, starting the transfer timer now.
    pub fn new(connection_id: u64, reader: R, writer: W, events: EventSender) -> Self {
        Self {
            t0: Instant::now(),
            connection_id,
            reader,
            writer,
            other_bytes_read: 0,
            events,
        }
    }

    /// Read the request.
    ///
    /// Will fail if there is an error while reading, or if no valid request is sent.
    ///
    /// This will read exactly the number of bytes needed for the request, and
    /// leave the rest of the stream for the caller to read.
    ///
    /// It is up to the caller to decide if there should be more data.
    pub async fn read_request(&mut self) -> Result<Request> {
        let (res, size) = Request::read_async(&mut self.reader).await?;
        // request bytes are accounted as non-payload bytes read
        self.other_bytes_read += size as u64;
        Ok(res)
    }

    /// We are done with reading. Return a ProgressWriter that contains the read stats and connection id
    ///
    /// Fails if the read side is not at EOF, i.e. the client sent more data
    /// than the request (presumably — see `expect_eof`).
    pub async fn into_writer(
        mut self,
        tracker: RequestTracker,
    ) -> Result<ProgressWriter<W>, io::Error> {
        self.reader.expect_eof().await?;
        drop(self.reader);
        Ok(ProgressWriter::new(
            self.writer,
            WriterContext {
                t0: self.t0,
                other_bytes_read: self.other_bytes_read,
                payload_bytes_written: 0,
                other_bytes_written: 0,
                tracker,
            },
        ))
    }

    /// We are done with writing. Flush the writer, drop it, and return a
    /// ProgressReader carrying the read stats for the rest of the request.
    pub async fn into_reader(
        mut self,
        tracker: RequestTracker,
    ) -> Result<ProgressReader<R>, io::Error> {
        self.writer.sync().await?;
        drop(self.writer);
        Ok(ProgressReader {
            inner: self.reader,
            context: ReaderContext {
                t0: self.t0,
                other_bytes_read: self.other_bytes_read,
                tracker,
            },
        })
    }

    /// Announce a get request to the event sender; returns the tracker used
    /// for progress reporting, or an error if the request is rejected.
    pub async fn get_request(
        &self,
        f: impl FnOnce() -> GetRequest,
    ) -> Result<RequestTracker, ProgressError> {
        self.events
            .request(f, self.connection_id, self.reader.id())
            .await
    }

    /// Announce a get_many request to the event sender (see [`Self::get_request`]).
    pub async fn get_many_request(
        &self,
        f: impl FnOnce() -> GetManyRequest,
    ) -> Result<RequestTracker, ProgressError> {
        self.events
            .request(f, self.connection_id, self.reader.id())
            .await
    }

    /// Announce a push request to the event sender (see [`Self::get_request`]).
    pub async fn push_request(
        &self,
        f: impl FnOnce() -> PushRequest,
    ) -> Result<RequestTracker, ProgressError> {
        self.events
            .request(f, self.connection_id, self.reader.id())
            .await
    }

    /// Announce an observe request to the event sender (see [`Self::get_request`]).
    pub async fn observe_request(
        &self,
        f: impl FnOnce() -> ObserveRequest,
    ) -> Result<RequestTracker, ProgressError> {
        self.events
            .request(f, self.connection_id, self.reader.id())
            .await
    }

    /// Statistics so far; nothing has been written yet, so the sent counters
    /// are zero.
    pub fn stats(&self) -> TransferStats {
        TransferStats {
            payload_bytes_sent: 0,
            other_bytes_sent: 0,
            other_bytes_read: self.other_bytes_read,
            duration: self.t0.elapsed(),
        }
    }
}
188
/// Per-request context for requests that read data from the client (push).
#[derive(Debug)]
struct ReaderContext {
    /// The start time of the transfer
    t0: Instant,
    /// The number of bytes read from the stream
    other_bytes_read: u64,
    /// Progress tracking for the request
    tracker: RequestTracker,
}
198
199impl ReaderContext {
200    fn stats(&self) -> TransferStats {
201        TransferStats {
202            payload_bytes_sent: 0,
203            other_bytes_sent: 0,
204            other_bytes_read: self.other_bytes_read,
205            duration: self.t0.elapsed(),
206        }
207    }
208}
209
/// Per-request context for requests that send data to the client
/// (get, get_many, observe).
#[derive(Debug)]
pub(crate) struct WriterContext {
    /// The start time of the transfer
    t0: Instant,
    /// The number of bytes read from the stream
    other_bytes_read: u64,
    /// The number of payload bytes written to the stream
    payload_bytes_written: u64,
    /// The number of bytes written that are not part of the payload
    other_bytes_written: u64,
    /// Way to report progress
    tracker: RequestTracker,
}
223
224impl WriterContext {
225    fn stats(&self) -> TransferStats {
226        TransferStats {
227            payload_bytes_sent: self.payload_bytes_written,
228            other_bytes_sent: self.other_bytes_written,
229            other_bytes_read: self.other_bytes_read,
230            duration: self.t0.elapsed(),
231        }
232    }
233}
234
impl WriteProgress for WriterContext {
    /// Record a payload write and forward cumulative progress to the tracker.
    async fn notify_payload_write(&mut self, _index: u64, offset: u64, len: usize) -> ClientResult {
        let len = len as u64;
        // end offset of this write within the blob
        let end_offset = offset + len;
        self.payload_bytes_written += len;
        self.tracker.transfer_progress(len, end_offset).await
    }

    /// Record bytes written that are not payload (hash pairs, size headers).
    fn log_other_write(&mut self, len: usize) {
        self.other_bytes_written += len as u64;
    }

    /// Notify the tracker that the transfer of one blob started; notification
    /// errors are deliberately ignored.
    async fn send_transfer_started(&mut self, index: u64, hash: &Hash, size: u64) {
        self.tracker.transfer_started(index, hash, size).await.ok();
    }
}
251
/// Wrapper for a [`SendStream`] with additional per request information.
#[derive(Debug)]
pub struct ProgressWriter<W: SendStream = DefaultWriter> {
    /// The underlying send stream to write to
    pub inner: W,
    /// Write statistics and progress tracking for this request.
    pub(crate) context: WriterContext,
}
259
260impl<W: SendStream> ProgressWriter<W> {
261    fn new(inner: W, context: WriterContext) -> Self {
262        Self { inner, context }
263    }
264
265    async fn transfer_aborted(&self) {
266        self.context
267            .tracker
268            .transfer_aborted(|| Box::new(self.context.stats()))
269            .await
270            .ok();
271    }
272
273    async fn transfer_completed(&self) {
274        self.context
275            .tracker
276            .transfer_completed(|| Box::new(self.context.stats()))
277            .await
278            .ok();
279    }
280}
281
/// Handle a single connection.
///
/// Emits a `ClientConnected` event first (which may reject the client and
/// close the connection), then accepts bidirectional streams in a loop,
/// spawning one task per stream, until accepting fails (e.g. the connection
/// was closed). Finally emits a `ConnectionClosed` event.
pub async fn handle_connection(
    connection: endpoint::Connection,
    store: Store,
    progress: EventSender,
) {
    let connection_id = connection.stable_id() as u64;
    let span = debug_span!("connection", connection_id);
    async move {
        // give the event handler a chance to reject this client
        if let Err(cause) = progress
            .client_connected(|| ClientConnected {
                connection_id,
                endpoint_id: Some(connection.remote_id()),
            })
            .await
        {
            connection.close(cause.code(), cause.reason());
            debug!("closing connection: {cause}");
            return;
        }
        while let Ok(pair) = StreamPair::accept(&connection, progress.clone()).await {
            let span = debug_span!("stream", stream_id = %pair.stream_id());
            let store = store.clone();
            // the task result is intentionally dropped; per-request failures
            // are reported through the request tracker inside handle_stream
            n0_future::task::spawn(handle_stream(pair, store).instrument(span));
        }
        progress
            .connection_closed(|| ConnectionClosed { connection_id })
            .await
            .ok();
    }
    .instrument(span)
    .await
}
315
/// Describes how to handle errors for a stream.
///
/// NOTE(review): this trait is not referenced anywhere in this file;
/// presumably implemented or used elsewhere — verify before removing.
pub trait ErrorHandler {
    /// Writer side this handler can reset.
    type W: AsyncStreamWriter;
    /// Reader side this handler can stop.
    type R: AsyncStreamReader;
    /// Stop the read side with the given error code.
    fn stop(reader: &mut Self::R, code: VarInt) -> impl Future<Output = ()>;
    /// Reset the write side with the given error code.
    fn reset(writer: &mut Self::W, code: VarInt) -> impl Future<Output = ()>;
}
323
324async fn handle_read_request_result<R: RecvStream, W: SendStream, T, E: HasErrorCode>(
325    pair: &mut StreamPair<R, W>,
326    r: Result<T, E>,
327) -> Result<T, E> {
328    match r {
329        Ok(x) => Ok(x),
330        Err(e) => {
331            pair.writer.reset(e.code()).ok();
332            Err(e)
333        }
334    }
335}
336async fn handle_write_result<W: SendStream, T, E: HasErrorCode>(
337    writer: &mut ProgressWriter<W>,
338    r: Result<T, E>,
339) -> Result<T, E> {
340    match r {
341        Ok(x) => {
342            writer.transfer_completed().await;
343            Ok(x)
344        }
345        Err(e) => {
346            writer.inner.reset(e.code()).ok();
347            writer.transfer_aborted().await;
348            Err(e)
349        }
350    }
351}
352async fn handle_read_result<R: RecvStream, T, E: HasErrorCode>(
353    reader: &mut ProgressReader<R>,
354    r: Result<T, E>,
355) -> Result<T, E> {
356    match r {
357        Ok(x) => {
358            reader.transfer_completed().await;
359            Ok(x)
360        }
361        Err(e) => {
362            reader.inner.stop(e.code()).ok();
363            reader.transfer_aborted().await;
364            Err(e)
365        }
366    }
367}
368
/// Read a single request from the stream pair and dispatch it to the
/// matching handler.
pub async fn handle_stream<R: RecvStream, W: SendStream>(
    mut pair: StreamPair<R, W>,
    store: Store,
) -> n0_error::Result<()> {
    let request = pair.read_request().await?;
    match request {
        Request::Get(request) => handle_get(pair, store, request).await?,
        Request::GetMany(request) => handle_get_many(pair, store, request).await?,
        Request::Observe(request) => handle_observe(pair, store, request).await?,
        Request::Push(request) => handle_push(pair, store, request).await?,
        // other request variants are silently ignored; the stream pair is
        // simply dropped
        _ => {}
    }
    Ok(())
}
383
/// Errors that can occur while handling a get request.
#[stack_error(derive, add_meta, from_sources)]
pub enum HandleGetError {
    /// Exporting bao-encoded data from the store failed.
    #[error(transparent)]
    ExportBao {
        #[error(std_err)]
        source: ExportBaoError,
    },
    /// The root blob could not be parsed as a hash sequence.
    #[error("Invalid hash sequence")]
    InvalidHashSeq {},
    /// A child offset did not fit into `usize`.
    #[error("Invalid offset")]
    InvalidOffset {},
}
396
397impl HasErrorCode for HandleGetError {
398    fn code(&self) -> VarInt {
399        match self {
400            HandleGetError::ExportBao {
401                source: ExportBaoError::ClientError { source, .. },
402                ..
403            } => source.code(),
404            HandleGetError::InvalidHashSeq { .. } => ERR_INTERNAL,
405            HandleGetError::InvalidOffset { .. } => ERR_INTERNAL,
406            _ => ERR_INTERNAL,
407        }
408    }
409}
410
/// Handle a single get request.
///
/// Requires a database, the request, and a writer.
async fn handle_get_impl<W: SendStream>(
    store: Store,
    request: GetRequest,
    writer: &mut ProgressWriter<W>,
) -> Result<(), HandleGetError> {
    let hash = request.hash;
    debug!(%hash, "get received request");
    // lazily loaded hash sequence, only fetched if a child offset is requested
    let mut hash_seq = None;
    for (offset, ranges) in request.ranges.iter_non_empty_infinite() {
        if offset == 0 {
            // offset 0 refers to the root blob itself
            send_blob(&store, offset, hash, ranges.clone(), writer).await?;
        } else {
            // todo: this assumes that 1. the hashseq is complete and 2. it is
            // small enough to fit in memory.
            //
            // This should really read the hashseq from the store in chunks,
            // only where needed, so we can deal with holes and large hashseqs.
            let hash_seq = match &hash_seq {
                Some(b) => b,
                None => {
                    let bytes = store.get_bytes(hash).await?;
                    let hs =
                        HashSeq::try_from(bytes).map_err(|_| e!(HandleGetError::InvalidHashSeq))?;
                    hash_seq = Some(hs);
                    hash_seq.as_ref().unwrap()
                }
            };
            // child n of the request corresponds to entry n-1 of the hash seq
            let o = usize::try_from(offset - 1).map_err(|_| e!(HandleGetError::InvalidOffset))?;
            let Some(hash) = hash_seq.get(o) else {
                // request ranges past the end of the hash seq: stop sending
                break;
            };
            send_blob(&store, offset, hash, ranges.clone(), writer).await?;
        }
    }
    // flush the stream so all data is actually sent before reporting success
    writer
        .inner
        .sync()
        .await
        .map_err(|e| e!(HandleGetError::ExportBao, e.into()))?;

    Ok(())
}
456
457pub async fn handle_get<R: RecvStream, W: SendStream>(
458    mut pair: StreamPair<R, W>,
459    store: Store,
460    request: GetRequest,
461) -> n0_error::Result<()> {
462    let res = pair.get_request(|| request.clone()).await;
463    let tracker = handle_read_request_result(&mut pair, res).await?;
464    let mut writer = pair.into_writer(tracker).await?;
465    let res = handle_get_impl(store, request, &mut writer).await;
466    handle_write_result(&mut writer, res).await?;
467    Ok(())
468}
469
/// Errors that can occur while handling a get_many request.
#[stack_error(derive, add_meta, from_sources)]
pub enum HandleGetManyError {
    /// Exporting bao-encoded data from the store failed.
    #[error(transparent)]
    ExportBao { source: ExportBaoError },
}
475
476impl HasErrorCode for HandleGetManyError {
477    fn code(&self) -> VarInt {
478        match self {
479            Self::ExportBao {
480                source: ExportBaoError::ClientError { source, .. },
481                ..
482            } => source.code(),
483            _ => ERR_INTERNAL,
484        }
485    }
486}
487
488/// Handle a single get request.
489///
490/// Requires a database, the request, and a writer.
491async fn handle_get_many_impl<W: SendStream>(
492    store: Store,
493    request: GetManyRequest,
494    writer: &mut ProgressWriter<W>,
495) -> Result<(), HandleGetManyError> {
496    debug!("get_many received request");
497    let request_ranges = request.ranges.iter_infinite();
498    for (child, (hash, ranges)) in request.hashes.iter().zip(request_ranges).enumerate() {
499        if !ranges.is_empty() {
500            send_blob(&store, child as u64, *hash, ranges.clone(), writer).await?;
501        }
502    }
503    Ok(())
504}
505
506pub async fn handle_get_many<R: RecvStream, W: SendStream>(
507    mut pair: StreamPair<R, W>,
508    store: Store,
509    request: GetManyRequest,
510) -> n0_error::Result<()> {
511    let res = pair.get_many_request(|| request.clone()).await;
512    let tracker = handle_read_request_result(&mut pair, res).await?;
513    let mut writer = pair.into_writer(tracker).await?;
514    let res = handle_get_many_impl(store, request, &mut writer).await;
515    handle_write_result(&mut writer, res).await?;
516    Ok(())
517}
518
/// Errors that can occur while handling a push request.
#[stack_error(derive, add_meta, from_sources)]
pub enum HandlePushError {
    /// Exporting bao-encoded data from the store failed.
    #[error(transparent)]
    ExportBao { source: ExportBaoError },

    /// The root blob could not be parsed as a hash sequence.
    #[error("Invalid hash sequence")]
    InvalidHashSeq {},

    /// Importing the pushed data into the store failed.
    #[error(transparent)]
    Request { source: RequestError },
}
530
531impl HasErrorCode for HandlePushError {
532    fn code(&self) -> VarInt {
533        match self {
534            Self::ExportBao {
535                source: ExportBaoError::ClientError { source, .. },
536                ..
537            } => source.code(),
538            _ => ERR_INTERNAL,
539        }
540    }
541}
542
/// Handle a single push request.
///
/// Requires a database, the request, and a reader.
async fn handle_push_impl<R: RecvStream>(
    store: Store,
    request: PushRequest,
    reader: &mut ProgressReader<R>,
) -> Result<(), HandlePushError> {
    let hash = request.hash;
    debug!(%hash, "push received request");
    let mut request_ranges = request.ranges.iter_infinite();
    // first entry of the range iterator is the root blob
    let root_ranges = request_ranges.next().expect("infinite iterator");
    if !root_ranges.is_empty() {
        // todo: send progress from import_bao_quinn or rename to import_bao_quinn_with_progress
        store
            .import_bao_reader(hash, root_ranges.clone(), &mut reader.inner)
            .await?;
    }
    if request.ranges.is_blob() {
        // a blob request has no children, so we are done
        debug!("push request complete");
        return Ok(());
    }
    // todo: we assume here that the hash sequence is complete. For some requests this might not be the case. We would need `LazyHashSeq` for that, but it is buggy as of now!
    let hash_seq = store.get_bytes(hash).await?;
    let hash_seq = HashSeq::try_from(hash_seq).map_err(|_| e!(HandlePushError::InvalidHashSeq))?;
    // remaining range entries line up with the children of the hash seq;
    // the client sends children in order on the same stream
    for (child_hash, child_ranges) in hash_seq.into_iter().zip(request_ranges) {
        if child_ranges.is_empty() {
            continue;
        }
        store
            .import_bao_reader(child_hash, child_ranges.clone(), &mut reader.inner)
            .await?;
    }
    Ok(())
}
578
579pub async fn handle_push<R: RecvStream, W: SendStream>(
580    mut pair: StreamPair<R, W>,
581    store: Store,
582    request: PushRequest,
583) -> n0_error::Result<()> {
584    let res = pair.push_request(|| request.clone()).await;
585    let tracker = handle_read_request_result(&mut pair, res).await?;
586    let mut reader = pair.into_reader(tracker).await?;
587    let res = handle_push_impl(store, request, &mut reader).await;
588    handle_read_result(&mut reader, res).await?;
589    Ok(())
590}
591
592/// Send a blob to the client.
593pub(crate) async fn send_blob<W: SendStream>(
594    store: &Store,
595    index: u64,
596    hash: Hash,
597    ranges: ChunkRanges,
598    writer: &mut ProgressWriter<W>,
599) -> ExportBaoResult<()> {
600    store
601        .export_bao(hash, ranges)
602        .write_with_progress(&mut writer.inner, &mut writer.context, &hash, index)
603        .await
604}
605
/// Errors that can occur while handling an observe request.
#[stack_error(derive, add_meta, std_sources, from_sources)]
pub enum HandleObserveError {
    /// The local observe stream ended or could not be created.
    #[error("observe stream closed")]
    ObserveStreamClosed {},

    /// Writing an update to the remote failed.
    #[error(transparent)]
    RemoteClosed { source: io::Error },
}
614
impl HasErrorCode for HandleObserveError {
    /// Observe failures have no dedicated wire code; always report internal.
    fn code(&self) -> VarInt {
        ERR_INTERNAL
    }
}
620
/// Handle a single observe request.
///
/// Requires a database, the request, and a writer. Streams bitfield updates
/// for the observed hash to the client until the client stops the stream.
async fn handle_observe_impl<W: SendStream>(
    store: Store,
    request: ObserveRequest,
    writer: &mut ProgressWriter<W>,
) -> std::result::Result<(), HandleObserveError> {
    let mut stream = store
        .observe(request.hash)
        .stream()
        .await
        .map_err(|_| e!(HandleObserveError::ObserveStreamClosed))?;
    let mut old = stream
        .next()
        .await
        .ok_or_else(|| e!(HandleObserveError::ObserveStreamClosed))?;
    // send the initial bitfield
    send_observe_item(writer, &old).await?;
    // send updates until the remote loses interest
    loop {
        select! {
            new = stream.next() => {
                let new = new.ok_or_else(|| e!(HandleObserveError::ObserveStreamClosed))?;
                // only send the delta relative to the last sent state
                let diff = old.diff(&new);
                if diff.is_empty() {
                    continue;
                }
                send_observe_item(writer, &diff).await?;
                old = new;
            }
            _ = writer.inner.stopped() => {
                // the remote stopped the stream: normal termination
                debug!("observer closed");
                break;
            }
        }
    }
    Ok(())
}
660
661async fn send_observe_item<W: SendStream>(
662    writer: &mut ProgressWriter<W>,
663    item: &Bitfield,
664) -> io::Result<()> {
665    let item = ObserveItem::from(item);
666    let len = writer.inner.write_length_prefixed(item).await?;
667    writer.context.log_other_write(len);
668    Ok(())
669}
670
671pub async fn handle_observe<R: RecvStream, W: SendStream>(
672    mut pair: StreamPair<R, W>,
673    store: Store,
674    request: ObserveRequest,
675) -> n0_error::Result<()> {
676    let res = pair.observe_request(|| request.clone()).await;
677    let tracker = handle_read_request_result(&mut pair, res).await?;
678    let mut writer = pair.into_writer(tracker).await?;
679    let res = handle_observe_impl(store, request, &mut writer).await;
680    handle_write_result(&mut writer, res).await?;
681    Ok(())
682}
683
/// Wrapper for a [`RecvStream`] with additional per request information.
pub struct ProgressReader<R: RecvStream = DefaultReader> {
    /// The underlying receive stream.
    inner: R,
    /// Read statistics and progress tracking for this request.
    context: ReaderContext,
}
688
689impl<R: RecvStream> ProgressReader<R> {
690    async fn transfer_aborted(&self) {
691        self.context
692            .tracker
693            .transfer_aborted(|| Box::new(self.context.stats()))
694            .await
695            .ok();
696    }
697
698    async fn transfer_completed(&self) {
699        self.context
700            .tracker
701            .transfer_completed(|| Box::new(self.context.stats()))
702            .await
703            .ok();
704    }
705}