Thanks for visiting codestin.com
Credit goes to n0-computer.github.io

iroh_blobs/api/
downloader.rs

1//! API for downloads from multiple nodes.
2use std::{
3    collections::{HashMap, HashSet},
4    fmt::Debug,
5    future::{Future, IntoFuture},
6    sync::Arc,
7};
8
9use genawaiter::sync::Gen;
10use iroh::{Endpoint, EndpointId};
11use irpc::{channel::mpsc, rpc_requests};
12use n0_error::{anyerr, Result};
13use n0_future::{future, stream, task::JoinSet, BufferedStreamExt, Stream, StreamExt};
14use rand::seq::SliceRandom;
15use serde::{de::Error, Deserialize, Serialize};
16use tracing::instrument::Instrument;
17
18use super::Store;
19use crate::{
20    protocol::{GetManyRequest, GetRequest},
21    util::{
22        connection_pool::ConnectionPool,
23        sink::{Drain, IrpcSenderRefSink, Sink, TokioMpscSenderSink},
24    },
25    BlobFormat, Hash, HashAndFormat,
26};
27
28#[derive(Debug, Clone)]
29pub struct Downloader {
30    client: irpc::Client<SwarmProtocol>,
31}
32
33#[rpc_requests(message = SwarmMsg, alias = "Msg", rpc_feature = "rpc")]
34#[derive(Debug, Serialize, Deserialize)]
35enum SwarmProtocol {
36    #[rpc(tx = mpsc::Sender<DownloadProgressItem>)]
37    Download(DownloadRequest),
38}
39
40struct DownloaderActor {
41    store: Store,
42    pool: ConnectionPool,
43    tasks: JoinSet<()>,
44    running: HashSet<n0_future::task::Id>,
45}
46
47#[derive(Debug, Serialize, Deserialize)]
48pub enum DownloadProgressItem {
49    #[serde(skip)]
50    Error(n0_error::AnyError),
51    TryProvider {
52        id: EndpointId,
53        request: Arc<GetRequest>,
54    },
55    ProviderFailed {
56        id: EndpointId,
57        request: Arc<GetRequest>,
58    },
59    PartComplete {
60        request: Arc<GetRequest>,
61    },
62    Progress(u64),
63    DownloadError,
64}
65
66impl DownloaderActor {
67    fn new(store: Store, endpoint: Endpoint) -> Self {
68        Self {
69            store,
70            pool: ConnectionPool::new(endpoint, crate::ALPN, Default::default()),
71            tasks: JoinSet::new(),
72            running: HashSet::new(),
73        }
74    }
75
76    async fn run(mut self, mut rx: tokio::sync::mpsc::Receiver<SwarmMsg>) {
77        while let Some(msg) = rx.recv().await {
78            match msg {
79                SwarmMsg::Download(request) => {
80                    self.spawn(handle_download(
81                        self.store.clone(),
82                        self.pool.clone(),
83                        request,
84                    ));
85                }
86            }
87        }
88    }
89
90    fn spawn(&mut self, fut: impl Future<Output = ()> + Send + 'static) {
91        let span = tracing::Span::current();
92        let id = self.tasks.spawn(fut.instrument(span)).id();
93        self.running.insert(id);
94    }
95}
96
97async fn handle_download(store: Store, pool: ConnectionPool, msg: DownloadMsg) {
98    let DownloadMsg { inner, mut tx, .. } = msg;
99    if let Err(cause) = handle_download_impl(store, pool, inner, &mut tx).await {
100        tx.send(DownloadProgressItem::Error(cause)).await.ok();
101    }
102}
103
104async fn handle_download_impl(
105    store: Store,
106    pool: ConnectionPool,
107    request: DownloadRequest,
108    tx: &mut mpsc::Sender<DownloadProgressItem>,
109) -> Result<()> {
110    match request.strategy {
111        SplitStrategy::Split => handle_download_split_impl(store, pool, request, tx).await?,
112        SplitStrategy::None => match request.request {
113            FiniteRequest::Get(get) => {
114                let sink = IrpcSenderRefSink(tx);
115                execute_get(&pool, Arc::new(get), &request.providers, &store, sink).await?;
116            }
117            FiniteRequest::GetMany(_) => {
118                handle_download_split_impl(store, pool, request, tx).await?
119            }
120        },
121    }
122    Ok(())
123}
124
125async fn handle_download_split_impl(
126    store: Store,
127    pool: ConnectionPool,
128    request: DownloadRequest,
129    tx: &mut mpsc::Sender<DownloadProgressItem>,
130) -> Result<()> {
131    let providers = request.providers;
132    let requests = split_request(&request.request, &providers, &pool, &store, Drain).await?;
133    let (progress_tx, progress_rx) = tokio::sync::mpsc::channel(32);
134    let mut futs = stream::iter(requests.into_iter().enumerate())
135        .map(|(id, request)| {
136            let pool = pool.clone();
137            let providers = providers.clone();
138            let store = store.clone();
139            let progress_tx = progress_tx.clone();
140            async move {
141                let hash = request.hash;
142                let (tx, rx) = tokio::sync::mpsc::channel::<(usize, DownloadProgressItem)>(16);
143                progress_tx.send(rx).await.ok();
144                let sink = TokioMpscSenderSink(tx).with_map(move |x| (id, x));
145                let res = execute_get(&pool, Arc::new(request), &providers, &store, sink).await;
146                (hash, res)
147            }
148        })
149        .buffered_unordered(32);
150    let mut progress_stream = {
151        let mut offsets = HashMap::new();
152        let mut total = 0;
153        into_stream(progress_rx)
154            .flat_map(into_stream)
155            .map(move |(id, item)| match item {
156                DownloadProgressItem::Progress(offset) => {
157                    total += offset;
158                    if let Some(prev) = offsets.insert(id, offset) {
159                        total -= prev;
160                    }
161                    DownloadProgressItem::Progress(total)
162                }
163                x => x,
164            })
165    };
166    loop {
167        tokio::select! {
168            Some(item) = progress_stream.next() => {
169                tx.send(item).await?;
170            },
171            res = futs.next() => {
172                match res {
173                    Some((_hash, Ok(()))) => {
174                    }
175                    Some((_hash, Err(_e))) => {
176                        tx.send(DownloadProgressItem::DownloadError).await?;
177                    }
178                    None => break,
179                }
180            }
181            _ = tx.closed() => {
182                // The sender has been closed, we should stop processing.
183                break;
184            }
185        }
186    }
187    Ok(())
188}
189
190fn into_stream<T>(mut recv: tokio::sync::mpsc::Receiver<T>) -> impl Stream<Item = T> {
191    Gen::new(|co| async move {
192        while let Some(item) = recv.recv().await {
193            co.yield_(item).await;
194        }
195    })
196}
197
198#[derive(Debug, Serialize, Deserialize, derive_more::From)]
199pub enum FiniteRequest {
200    Get(GetRequest),
201    GetMany(GetManyRequest),
202}
203
204pub trait SupportedRequest {
205    fn into_request(self) -> FiniteRequest;
206}
207
208impl<I: Into<Hash>, T: IntoIterator<Item = I>> SupportedRequest for T {
209    fn into_request(self) -> FiniteRequest {
210        let hashes = self.into_iter().map(Into::into).collect::<GetManyRequest>();
211        FiniteRequest::GetMany(hashes)
212    }
213}
214
215impl SupportedRequest for GetRequest {
216    fn into_request(self) -> FiniteRequest {
217        self.into()
218    }
219}
220
221impl SupportedRequest for GetManyRequest {
222    fn into_request(self) -> FiniteRequest {
223        self.into()
224    }
225}
226
227impl SupportedRequest for Hash {
228    fn into_request(self) -> FiniteRequest {
229        GetRequest::blob(self).into()
230    }
231}
232
233impl SupportedRequest for HashAndFormat {
234    fn into_request(self) -> FiniteRequest {
235        (match self.format {
236            BlobFormat::Raw => GetRequest::blob(self.hash),
237            BlobFormat::HashSeq => GetRequest::all(self.hash),
238        })
239        .into()
240    }
241}
242
243#[derive(Debug, Serialize, Deserialize)]
244pub struct AddProviderRequest {
245    pub hash: Hash,
246    pub providers: Vec<EndpointId>,
247}
248
249#[derive(Debug)]
250pub struct DownloadRequest {
251    pub request: FiniteRequest,
252    pub providers: Arc<dyn ContentDiscovery>,
253    pub strategy: SplitStrategy,
254}
255
256impl DownloadRequest {
257    pub fn new(
258        request: impl SupportedRequest,
259        providers: impl ContentDiscovery,
260        strategy: SplitStrategy,
261    ) -> Self {
262        Self {
263            request: request.into_request(),
264            providers: Arc::new(providers),
265            strategy,
266        }
267    }
268}
269
270#[derive(Debug, Serialize, Deserialize)]
271pub enum SplitStrategy {
272    None,
273    Split,
274}
275
276impl Serialize for DownloadRequest {
277    fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
278    where
279        S: serde::Serializer,
280    {
281        Err(serde::ser::Error::custom(
282            "cannot serialize DownloadRequest",
283        ))
284    }
285}
286
287// Implement Deserialize to always fail
288impl<'de> Deserialize<'de> for DownloadRequest {
289    fn deserialize<D>(_deserializer: D) -> Result<Self, D::Error>
290    where
291        D: serde::Deserializer<'de>,
292    {
293        Err(D::Error::custom("cannot deserialize DownloadRequest"))
294    }
295}
296
297pub type DownloadOptions = DownloadRequest;
298
299pub struct DownloadProgress {
300    fut: future::Boxed<irpc::Result<mpsc::Receiver<DownloadProgressItem>>>,
301}
302
303impl DownloadProgress {
304    fn new(fut: future::Boxed<irpc::Result<mpsc::Receiver<DownloadProgressItem>>>) -> Self {
305        Self { fut }
306    }
307
308    pub async fn stream(self) -> irpc::Result<impl Stream<Item = DownloadProgressItem> + Unpin> {
309        let rx = self.fut.await?;
310        Ok(Box::pin(rx.into_stream().map(|item| match item {
311            Ok(item) => item,
312            Err(e) => DownloadProgressItem::Error(e.into()),
313        })))
314    }
315
316    async fn complete(self) -> Result<()> {
317        let rx = self.fut.await?;
318        let stream = rx.into_stream();
319        tokio::pin!(stream);
320        while let Some(item) = stream.next().await {
321            match item? {
322                DownloadProgressItem::Error(e) => Err(e)?,
323                DownloadProgressItem::DownloadError => {
324                    n0_error::bail_any!("Download error");
325                }
326                _ => {}
327            }
328        }
329        Ok(())
330    }
331}
332
333impl IntoFuture for DownloadProgress {
334    type Output = Result<()>;
335    type IntoFuture = future::Boxed<Self::Output>;
336
337    fn into_future(self) -> Self::IntoFuture {
338        Box::pin(self.complete())
339    }
340}
341
342impl Downloader {
343    pub fn new(store: &Store, endpoint: &Endpoint) -> Self {
344        let (tx, rx) = tokio::sync::mpsc::channel::<SwarmMsg>(32);
345        let actor = DownloaderActor::new(store.clone(), endpoint.clone());
346        n0_future::task::spawn(actor.run(rx));
347        Self { client: tx.into() }
348    }
349
350    pub fn download(
351        &self,
352        request: impl SupportedRequest,
353        providers: impl ContentDiscovery,
354    ) -> DownloadProgress {
355        let request = request.into_request();
356        let providers = Arc::new(providers);
357        self.download_with_opts(DownloadOptions {
358            request,
359            providers,
360            strategy: SplitStrategy::None,
361        })
362    }
363
364    pub fn download_with_opts(&self, options: DownloadOptions) -> DownloadProgress {
365        let fut = self.client.server_streaming(options, 32);
366        DownloadProgress::new(Box::pin(fut))
367    }
368}
369
370/// Split a request into multiple requests that can be run in parallel.
371async fn split_request<'a>(
372    request: &'a FiniteRequest,
373    providers: &Arc<dyn ContentDiscovery>,
374    pool: &ConnectionPool,
375    store: &Store,
376    progress: impl Sink<DownloadProgressItem, Error = irpc::channel::SendError>,
377) -> Result<Box<dyn Iterator<Item = GetRequest> + Send + 'a>> {
378    Ok(match request {
379        FiniteRequest::Get(req) => {
380            let Some(_first) = req.ranges.iter_infinite().next() else {
381                return Ok(Box::new(std::iter::empty()));
382            };
383            let first = GetRequest::blob(req.hash);
384            execute_get(pool, Arc::new(first), providers, store, progress).await?;
385            let size = store.observe(req.hash).await?.size();
386            n0_error::ensure_any!(size % 32 == 0, "Size is not a multiple of 32");
387            let n = size / 32;
388            Box::new(
389                req.ranges
390                    .iter_infinite()
391                    .take(n as usize + 1)
392                    .enumerate()
393                    .filter_map(|(i, ranges)| {
394                        if i != 0 && !ranges.is_empty() {
395                            Some(
396                                GetRequest::builder()
397                                    .offset(i as u64, ranges.clone())
398                                    .build(req.hash),
399                            )
400                        } else {
401                            None
402                        }
403                    }),
404            )
405        }
406        FiniteRequest::GetMany(req) => Box::new(
407            req.hashes
408                .iter()
409                .enumerate()
410                .map(|(i, hash)| GetRequest::blob_ranges(*hash, req.ranges[i as u64].clone())),
411        ),
412    })
413}
414
415/// Execute a get request sequentially for multiple providers.
416///
417/// It will try each provider in order
418/// until it finds one that can fulfill the request. When trying a new provider,
419/// it takes the progress from the previous providers into account, so e.g.
420/// if the first provider had the first 10% of the data, it will only ask the next
421/// provider for the remaining 90%.
422///
423/// This is fully sequential, so there will only be one request in flight at a time.
424///
425/// If the request is not complete after trying all providers, it will return an error.
426/// If the provider stream never ends, it will try indefinitely.
427async fn execute_get(
428    pool: &ConnectionPool,
429    request: Arc<GetRequest>,
430    providers: &Arc<dyn ContentDiscovery>,
431    store: &Store,
432    mut progress: impl Sink<DownloadProgressItem, Error = irpc::channel::SendError>,
433) -> Result<()> {
434    let remote = store.remote();
435    let mut providers = providers.find_providers(request.content());
436    while let Some(provider) = providers.next().await {
437        progress
438            .send(DownloadProgressItem::TryProvider {
439                id: provider,
440                request: request.clone(),
441            })
442            .await?;
443        let conn = pool.get_or_connect(provider);
444        let local = remote.local_for_request(request.clone()).await?;
445        if local.is_complete() {
446            return Ok(());
447        }
448        let local_bytes = local.local_bytes();
449        let Ok(conn) = conn.await else {
450            progress
451                .send(DownloadProgressItem::ProviderFailed {
452                    id: provider,
453                    request: request.clone(),
454                })
455                .await?;
456            continue;
457        };
458        match remote
459            .execute_get_sink(
460                conn.clone(),
461                local.missing(),
462                (&mut progress).with_map(move |x| DownloadProgressItem::Progress(x + local_bytes)),
463            )
464            .await
465        {
466            Ok(_stats) => {
467                progress
468                    .send(DownloadProgressItem::PartComplete {
469                        request: request.clone(),
470                    })
471                    .await?;
472                return Ok(());
473            }
474            Err(_cause) => {
475                progress
476                    .send(DownloadProgressItem::ProviderFailed {
477                        id: provider,
478                        request: request.clone(),
479                    })
480                    .await?;
481                continue;
482            }
483        }
484    }
485    Err(anyerr!("Unable to download {}", request.hash))
486}
487
488/// Trait for pluggable content discovery strategies.
489pub trait ContentDiscovery: Debug + Send + Sync + 'static {
490    fn find_providers(&self, hash: HashAndFormat) -> n0_future::stream::Boxed<EndpointId>;
491}
492
493impl<C, I> ContentDiscovery for C
494where
495    C: Debug + Clone + IntoIterator<Item = I> + Send + Sync + 'static,
496    C::IntoIter: Send + Sync + 'static,
497    I: Into<EndpointId> + Send + Sync + 'static,
498{
499    fn find_providers(&self, _: HashAndFormat) -> n0_future::stream::Boxed<EndpointId> {
500        let providers = self.clone();
501        n0_future::stream::iter(providers.into_iter().map(Into::into)).boxed()
502    }
503}
504
505#[derive(derive_more::Debug)]
506pub struct Shuffled {
507    nodes: Vec<EndpointId>,
508}
509
510impl Shuffled {
511    pub fn new(nodes: Vec<EndpointId>) -> Self {
512        Self { nodes }
513    }
514}
515
516impl ContentDiscovery for Shuffled {
517    fn find_providers(&self, _: HashAndFormat) -> n0_future::stream::Boxed<EndpointId> {
518        let mut nodes = self.nodes.clone();
519        nodes.shuffle(&mut rand::rng());
520        n0_future::stream::iter(nodes).boxed()
521    }
522}
523
#[cfg(test)]
#[cfg(feature = "fs-store")]
mod tests {
    use std::ops::Deref;

    use bao_tree::ChunkRanges;
    use n0_future::StreamExt;
    use testresult::TestResult;

    use crate::{
        api::{
            blobs::AddBytesOptions,
            downloader::{DownloadOptions, Downloader, Shuffled, SplitStrategy},
        },
        hashseq::HashSeq,
        protocol::{GetManyRequest, GetRequest},
        tests::node_test_setup_fs,
    };

    /// Downloads two small blobs, each held by a different node, through one
    /// `GetMany` request and checks that both end up in the local store.
    #[tokio::test]
    #[ignore = "todo"]
    async fn downloader_get_many_smoke() -> TestResult<()> {
        let testdir = tempfile::tempdir()?;
        let (r1, store1, _, _) = node_test_setup_fs(testdir.path().join("a")).await?;
        let (r2, store2, _, _) = node_test_setup_fs(testdir.path().join("b")).await?;
        let (r3, store3, _, sp3) = node_test_setup_fs(testdir.path().join("c")).await?;
        let blob1 = store1.add_slice("hello world").await?;
        let blob2 = store2.add_slice("hello world 2").await?;
        let node1_addr = r1.endpoint().addr();
        let node1_id = node1_addr.id;
        let node2_addr = r2.endpoint().addr();
        let node2_id = node2_addr.id;
        let downloader = Downloader::new(&store3, r3.endpoint());
        // Make both provider addresses known to the downloading node.
        sp3.add_endpoint_info(node1_addr.clone());
        sp3.add_endpoint_info(node2_addr.clone());
        let request = GetManyRequest::builder()
            .hash(blob1.hash, ChunkRanges::all())
            .hash(blob2.hash, ChunkRanges::all())
            .build();
        let mut events = downloader
            .download(request, Shuffled::new(vec![node1_id, node2_id]))
            .stream()
            .await?;
        // Drive the download to the end of its event stream.
        while let Some(_event) = events.next().await {}
        assert_eq!(store3.get_bytes(blob1.hash).await?.deref(), b"hello world");
        assert_eq!(store3.get_bytes(blob2.hash).await?.deref(), b"hello world 2");
        Ok(())
    }

    /// Runs a split download of a hash sequence whose children are spread
    /// over two nodes, and only checks that the event stream terminates.
    #[tokio::test]
    async fn downloader_get_smoke() -> TestResult<()> {
        // tracing_subscriber::fmt::try_init().ok();
        let testdir = tempfile::tempdir()?;
        let (r1, store1, _, _) = node_test_setup_fs(testdir.path().join("a")).await?;
        let (r2, store2, _, _) = node_test_setup_fs(testdir.path().join("b")).await?;
        let (r3, store3, _, sp3) = node_test_setup_fs(testdir.path().join("c")).await?;
        let blob1 = store1.add_slice(vec![1; 10000000]).await?;
        let blob2 = store2.add_slice(vec![2; 10000000]).await?;
        let hs = [blob1.hash, blob2.hash].into_iter().collect::<HashSeq>();
        // The hash sequence itself is stored on node 1.
        let root = store1
            .add_bytes_with_opts(AddBytesOptions {
                data: hs.clone().into(),
                format: crate::BlobFormat::HashSeq,
            })
            .await?;
        let node1_addr = r1.endpoint().addr();
        let node1_id = node1_addr.id;
        let node2_addr = r2.endpoint().addr();
        let node2_id = node2_addr.id;
        let downloader = Downloader::new(&store3, r3.endpoint());
        sp3.add_endpoint_info(node1_addr.clone());
        sp3.add_endpoint_info(node2_addr.clone());
        let request = GetRequest::builder()
            .root(ChunkRanges::all())
            .next(ChunkRanges::all())
            .next(ChunkRanges::all())
            .build(root.hash);
        {
            let options =
                DownloadOptions::new(request, [node1_id, node2_id], SplitStrategy::Split);
            let mut events = downloader.download_with_opts(options).stream().await?;
            while let Some(_event) = events.next().await {}
        }
        // Disabled reference code: the same download done by hand over a
        // single connection. Kept so that it continues to type-check.
        if false {
            let conn = r3.endpoint().connect(node1_addr, crate::ALPN).await?;
            let remote = store3.remote();
            let _rh = remote
                .execute_get(
                    conn.clone(),
                    GetRequest::builder()
                        .root(ChunkRanges::all())
                        .build(root.hash),
                )
                .await?;
            let h1 = remote.execute_get(
                conn.clone(),
                GetRequest::builder()
                    .child(0, ChunkRanges::all())
                    .build(root.hash),
            );
            let h2 = remote.execute_get(
                conn.clone(),
                GetRequest::builder()
                    .child(1, ChunkRanges::all())
                    .build(root.hash),
            );
            h1.await?;
            h2.await?;
        }
        Ok(())
    }

    /// Runs a split download of `GetRequest::all` for a hash sequence and
    /// only checks that the event stream terminates.
    #[tokio::test]
    async fn downloader_get_all() -> TestResult<()> {
        let testdir = tempfile::tempdir()?;
        let (r1, store1, _, _) = node_test_setup_fs(testdir.path().join("a")).await?;
        let (r2, store2, _, _) = node_test_setup_fs(testdir.path().join("b")).await?;
        let (r3, store3, _, sp3) = node_test_setup_fs(testdir.path().join("c")).await?;
        let blob1 = store1.add_slice(vec![1; 10000000]).await?;
        let blob2 = store2.add_slice(vec![2; 10000000]).await?;
        let hs = [blob1.hash, blob2.hash].into_iter().collect::<HashSeq>();
        let root = store1
            .add_bytes_with_opts(AddBytesOptions {
                data: hs.clone().into(),
                format: crate::BlobFormat::HashSeq,
            })
            .await?;
        let node1_addr = r1.endpoint().addr();
        let node1_id = node1_addr.id;
        let node2_addr = r2.endpoint().addr();
        let node2_id = node2_addr.id;
        let downloader = Downloader::new(&store3, r3.endpoint());
        sp3.add_endpoint_info(node1_addr.clone());
        sp3.add_endpoint_info(node2_addr.clone());
        let request = GetRequest::all(root.hash);
        let options = DownloadOptions::new(request, [node1_id, node2_id], SplitStrategy::Split);
        let mut events = downloader.download_with_opts(options).stream().await?;
        while let Some(_event) = events.next().await {}
        Ok(())
    }
}