1use std::{
3 collections::{HashMap, HashSet},
4 fmt::Debug,
5 future::{Future, IntoFuture},
6 sync::Arc,
7};
8
9use genawaiter::sync::Gen;
10use iroh::{Endpoint, EndpointId};
11use irpc::{channel::mpsc, rpc_requests};
12use n0_error::{anyerr, Result};
13use n0_future::{future, stream, task::JoinSet, BufferedStreamExt, Stream, StreamExt};
14use rand::seq::SliceRandom;
15use serde::{de::Error, Deserialize, Serialize};
16use tracing::instrument::Instrument;
17
18use super::Store;
19use crate::{
20 protocol::{GetManyRequest, GetRequest},
21 util::{
22 connection_pool::ConnectionPool,
23 sink::{Drain, IrpcSenderRefSink, Sink, TokioMpscSenderSink},
24 },
25 BlobFormat, Hash, HashAndFormat,
26};
27
28#[derive(Debug, Clone)]
29pub struct Downloader {
30 client: irpc::Client<SwarmProtocol>,
31}
32
// RPC protocol spoken between `Downloader` handles and the `DownloaderActor`.
// `rpc_requests` generates the `SwarmMsg` message enum matched in `DownloaderActor::run`,
// and (judging by its use in `handle_download`) a `DownloadMsg` wrapper that carries the
// request in `inner` together with its progress sender `tx`.
#[rpc_requests(message = SwarmMsg, alias = "Msg", rpc_feature = "rpc")]
#[derive(Debug, Serialize, Deserialize)]
enum SwarmProtocol {
    // Start a download; progress items are streamed back over the `tx` channel.
    #[rpc(tx = mpsc::Sender<DownloadProgressItem>)]
    Download(DownloadRequest),
}
39
40struct DownloaderActor {
41 store: Store,
42 pool: ConnectionPool,
43 tasks: JoinSet<()>,
44 running: HashSet<n0_future::task::Id>,
45}
46
47#[derive(Debug, Serialize, Deserialize)]
48pub enum DownloadProgressItem {
49 #[serde(skip)]
50 Error(n0_error::AnyError),
51 TryProvider {
52 id: EndpointId,
53 request: Arc<GetRequest>,
54 },
55 ProviderFailed {
56 id: EndpointId,
57 request: Arc<GetRequest>,
58 },
59 PartComplete {
60 request: Arc<GetRequest>,
61 },
62 Progress(u64),
63 DownloadError,
64}
65
66impl DownloaderActor {
67 fn new(store: Store, endpoint: Endpoint) -> Self {
68 Self {
69 store,
70 pool: ConnectionPool::new(endpoint, crate::ALPN, Default::default()),
71 tasks: JoinSet::new(),
72 running: HashSet::new(),
73 }
74 }
75
76 async fn run(mut self, mut rx: tokio::sync::mpsc::Receiver<SwarmMsg>) {
77 while let Some(msg) = rx.recv().await {
78 match msg {
79 SwarmMsg::Download(request) => {
80 self.spawn(handle_download(
81 self.store.clone(),
82 self.pool.clone(),
83 request,
84 ));
85 }
86 }
87 }
88 }
89
90 fn spawn(&mut self, fut: impl Future<Output = ()> + Send + 'static) {
91 let span = tracing::Span::current();
92 let id = self.tasks.spawn(fut.instrument(span)).id();
93 self.running.insert(id);
94 }
95}
96
97async fn handle_download(store: Store, pool: ConnectionPool, msg: DownloadMsg) {
98 let DownloadMsg { inner, mut tx, .. } = msg;
99 if let Err(cause) = handle_download_impl(store, pool, inner, &mut tx).await {
100 tx.send(DownloadProgressItem::Error(cause)).await.ok();
101 }
102}
103
104async fn handle_download_impl(
105 store: Store,
106 pool: ConnectionPool,
107 request: DownloadRequest,
108 tx: &mut mpsc::Sender<DownloadProgressItem>,
109) -> Result<()> {
110 match request.strategy {
111 SplitStrategy::Split => handle_download_split_impl(store, pool, request, tx).await?,
112 SplitStrategy::None => match request.request {
113 FiniteRequest::Get(get) => {
114 let sink = IrpcSenderRefSink(tx);
115 execute_get(&pool, Arc::new(get), &request.providers, &store, sink).await?;
116 }
117 FiniteRequest::GetMany(_) => {
118 handle_download_split_impl(store, pool, request, tx).await?
119 }
120 },
121 }
122 Ok(())
123}
124
125async fn handle_download_split_impl(
126 store: Store,
127 pool: ConnectionPool,
128 request: DownloadRequest,
129 tx: &mut mpsc::Sender<DownloadProgressItem>,
130) -> Result<()> {
131 let providers = request.providers;
132 let requests = split_request(&request.request, &providers, &pool, &store, Drain).await?;
133 let (progress_tx, progress_rx) = tokio::sync::mpsc::channel(32);
134 let mut futs = stream::iter(requests.into_iter().enumerate())
135 .map(|(id, request)| {
136 let pool = pool.clone();
137 let providers = providers.clone();
138 let store = store.clone();
139 let progress_tx = progress_tx.clone();
140 async move {
141 let hash = request.hash;
142 let (tx, rx) = tokio::sync::mpsc::channel::<(usize, DownloadProgressItem)>(16);
143 progress_tx.send(rx).await.ok();
144 let sink = TokioMpscSenderSink(tx).with_map(move |x| (id, x));
145 let res = execute_get(&pool, Arc::new(request), &providers, &store, sink).await;
146 (hash, res)
147 }
148 })
149 .buffered_unordered(32);
150 let mut progress_stream = {
151 let mut offsets = HashMap::new();
152 let mut total = 0;
153 into_stream(progress_rx)
154 .flat_map(into_stream)
155 .map(move |(id, item)| match item {
156 DownloadProgressItem::Progress(offset) => {
157 total += offset;
158 if let Some(prev) = offsets.insert(id, offset) {
159 total -= prev;
160 }
161 DownloadProgressItem::Progress(total)
162 }
163 x => x,
164 })
165 };
166 loop {
167 tokio::select! {
168 Some(item) = progress_stream.next() => {
169 tx.send(item).await?;
170 },
171 res = futs.next() => {
172 match res {
173 Some((_hash, Ok(()))) => {
174 }
175 Some((_hash, Err(_e))) => {
176 tx.send(DownloadProgressItem::DownloadError).await?;
177 }
178 None => break,
179 }
180 }
181 _ = tx.closed() => {
182 break;
184 }
185 }
186 }
187 Ok(())
188}
189
190fn into_stream<T>(mut recv: tokio::sync::mpsc::Receiver<T>) -> impl Stream<Item = T> {
191 Gen::new(|co| async move {
192 while let Some(item) = recv.recv().await {
193 co.yield_(item).await;
194 }
195 })
196}
197
198#[derive(Debug, Serialize, Deserialize, derive_more::From)]
199pub enum FiniteRequest {
200 Get(GetRequest),
201 GetMany(GetManyRequest),
202}
203
204pub trait SupportedRequest {
205 fn into_request(self) -> FiniteRequest;
206}
207
208impl<I: Into<Hash>, T: IntoIterator<Item = I>> SupportedRequest for T {
209 fn into_request(self) -> FiniteRequest {
210 let hashes = self.into_iter().map(Into::into).collect::<GetManyRequest>();
211 FiniteRequest::GetMany(hashes)
212 }
213}
214
215impl SupportedRequest for GetRequest {
216 fn into_request(self) -> FiniteRequest {
217 self.into()
218 }
219}
220
221impl SupportedRequest for GetManyRequest {
222 fn into_request(self) -> FiniteRequest {
223 self.into()
224 }
225}
226
227impl SupportedRequest for Hash {
228 fn into_request(self) -> FiniteRequest {
229 GetRequest::blob(self).into()
230 }
231}
232
233impl SupportedRequest for HashAndFormat {
234 fn into_request(self) -> FiniteRequest {
235 (match self.format {
236 BlobFormat::Raw => GetRequest::blob(self.hash),
237 BlobFormat::HashSeq => GetRequest::all(self.hash),
238 })
239 .into()
240 }
241}
242
243#[derive(Debug, Serialize, Deserialize)]
244pub struct AddProviderRequest {
245 pub hash: Hash,
246 pub providers: Vec<EndpointId>,
247}
248
249#[derive(Debug)]
250pub struct DownloadRequest {
251 pub request: FiniteRequest,
252 pub providers: Arc<dyn ContentDiscovery>,
253 pub strategy: SplitStrategy,
254}
255
256impl DownloadRequest {
257 pub fn new(
258 request: impl SupportedRequest,
259 providers: impl ContentDiscovery,
260 strategy: SplitStrategy,
261 ) -> Self {
262 Self {
263 request: request.into_request(),
264 providers: Arc::new(providers),
265 strategy,
266 }
267 }
268}
269
270#[derive(Debug, Serialize, Deserialize)]
271pub enum SplitStrategy {
272 None,
273 Split,
274}
275
276impl Serialize for DownloadRequest {
277 fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
278 where
279 S: serde::Serializer,
280 {
281 Err(serde::ser::Error::custom(
282 "cannot serialize DownloadRequest",
283 ))
284 }
285}
286
287impl<'de> Deserialize<'de> for DownloadRequest {
289 fn deserialize<D>(_deserializer: D) -> Result<Self, D::Error>
290 where
291 D: serde::Deserializer<'de>,
292 {
293 Err(D::Error::custom("cannot deserialize DownloadRequest"))
294 }
295}
296
297pub type DownloadOptions = DownloadRequest;
298
299pub struct DownloadProgress {
300 fut: future::Boxed<irpc::Result<mpsc::Receiver<DownloadProgressItem>>>,
301}
302
303impl DownloadProgress {
304 fn new(fut: future::Boxed<irpc::Result<mpsc::Receiver<DownloadProgressItem>>>) -> Self {
305 Self { fut }
306 }
307
308 pub async fn stream(self) -> irpc::Result<impl Stream<Item = DownloadProgressItem> + Unpin> {
309 let rx = self.fut.await?;
310 Ok(Box::pin(rx.into_stream().map(|item| match item {
311 Ok(item) => item,
312 Err(e) => DownloadProgressItem::Error(e.into()),
313 })))
314 }
315
316 async fn complete(self) -> Result<()> {
317 let rx = self.fut.await?;
318 let stream = rx.into_stream();
319 tokio::pin!(stream);
320 while let Some(item) = stream.next().await {
321 match item? {
322 DownloadProgressItem::Error(e) => Err(e)?,
323 DownloadProgressItem::DownloadError => {
324 n0_error::bail_any!("Download error");
325 }
326 _ => {}
327 }
328 }
329 Ok(())
330 }
331}
332
333impl IntoFuture for DownloadProgress {
334 type Output = Result<()>;
335 type IntoFuture = future::Boxed<Self::Output>;
336
337 fn into_future(self) -> Self::IntoFuture {
338 Box::pin(self.complete())
339 }
340}
341
342impl Downloader {
343 pub fn new(store: &Store, endpoint: &Endpoint) -> Self {
344 let (tx, rx) = tokio::sync::mpsc::channel::<SwarmMsg>(32);
345 let actor = DownloaderActor::new(store.clone(), endpoint.clone());
346 n0_future::task::spawn(actor.run(rx));
347 Self { client: tx.into() }
348 }
349
350 pub fn download(
351 &self,
352 request: impl SupportedRequest,
353 providers: impl ContentDiscovery,
354 ) -> DownloadProgress {
355 let request = request.into_request();
356 let providers = Arc::new(providers);
357 self.download_with_opts(DownloadOptions {
358 request,
359 providers,
360 strategy: SplitStrategy::None,
361 })
362 }
363
364 pub fn download_with_opts(&self, options: DownloadOptions) -> DownloadProgress {
365 let fut = self.client.server_streaming(options, 32);
366 DownloadProgress::new(Box::pin(fut))
367 }
368}
369
/// Splits `request` into independent [`GetRequest`]s that can be downloaded concurrently.
///
/// - `FiniteRequest::GetMany`: one `blob_ranges` request per hash, using that hash's
///   entry in `req.ranges`.
/// - `FiniteRequest::Get`:
///   1. Downloads the root as a plain blob (via [`execute_get`] with `providers`,
///      reporting to `progress`).
///   2. Requires its size to be a multiple of 32, and treats `size / 32` as the number of
///      child hashes.
///   3. Returns one request per offset `1..=n` whose ranges in `req.ranges` are non-empty.
///
///   Offset 0 (the root) is never part of the result.
///
/// # Errors
/// Fails if the root download fails, if observing the root in `store` fails, or if the
/// root's size is not a multiple of 32.
async fn split_request<'a>(
    request: &'a FiniteRequest,
    providers: &Arc<dyn ContentDiscovery>,
    pool: &ConnectionPool,
    store: &Store,
    progress: impl Sink<DownloadProgressItem, Error = irpc::channel::SendError>,
) -> Result<Box<dyn Iterator<Item = GetRequest> + Send + 'a>> {
    Ok(match request {
        FiniteRequest::Get(req) => {
            // NOTE(review): `iter_infinite()` suggests this iterator never ends, in which case
            // this early return can never fire — confirm the intended "empty request" check.
            let Some(_first) = req.ranges.iter_infinite().next() else {
                return Ok(Box::new(std::iter::empty()));
            };
            // The root is fetched as a whole blob, independent of `req.ranges`, because its
            // size is needed below to know how many children there are.
            let first = GetRequest::blob(req.hash);
            execute_get(pool, Arc::new(first), providers, store, progress).await?;
            let size = store.observe(req.hash).await?.size();
            n0_error::ensure_any!(size % 32 == 0, "Size is not a multiple of 32");
            // Number of 32-byte hashes in the root.
            let n = size / 32;
            Box::new(
                req.ranges
                    .iter_infinite()
                    // Offset 0 corresponds to the root; presumably offsets 1..=n are the
                    // n children.
                    .take(n as usize + 1)
                    .enumerate()
                    .filter_map(|(i, ranges)| {
                        // Skip the root and any child with nothing requested.
                        if i != 0 && !ranges.is_empty() {
                            Some(
                                GetRequest::builder()
                                    .offset(i as u64, ranges.clone())
                                    .build(req.hash),
                            )
                        } else {
                            None
                        }
                    }),
            )
        }
        FiniteRequest::GetMany(req) => Box::new(
            req.hashes
                .iter()
                .enumerate()
                .map(|(i, hash)| GetRequest::blob_ranges(*hash, req.ranges[i as u64].clone())),
        ),
    })
}
414
/// Downloads `request` into `store`, trying providers in the order `providers` yields them.
///
/// For each candidate provider this:
/// 1. reports `TryProvider` on `progress`;
/// 2. re-checks what is present locally (an earlier failed attempt may have stored part of
///    the data) and returns `Ok(())` without further events if nothing is missing;
/// 3. awaits the connection from `pool`; if that fails, reports `ProviderFailed` and moves
///    on to the next provider;
/// 4. fetches only the missing ranges, then reports `PartComplete` and returns on success,
///    or reports `ProviderFailed` and moves on to the next provider on failure.
///
/// # Errors
/// Fails if sending on `progress` fails, if the local-state lookup fails, or with
/// "Unable to download {hash}" once the provider stream is exhausted without success.
/// Because the local check only happens inside the loop, an empty provider stream yields
/// that error even when the data is already complete locally.
async fn execute_get(
    pool: &ConnectionPool,
    request: Arc<GetRequest>,
    providers: &Arc<dyn ContentDiscovery>,
    store: &Store,
    mut progress: impl Sink<DownloadProgressItem, Error = irpc::channel::SendError>,
) -> Result<()> {
    let remote = store.remote();
    let mut providers = providers.find_providers(request.content());
    while let Some(provider) = providers.next().await {
        progress
            .send(DownloadProgressItem::TryProvider {
                id: provider,
                request: request.clone(),
            })
            .await?;
        // Obtained before the local check, but only awaited after it.
        let conn = pool.get_or_connect(provider);
        let local = remote.local_for_request(request.clone()).await?;
        if local.is_complete() {
            return Ok(());
        }
        let local_bytes = local.local_bytes();
        let Ok(conn) = conn.await else {
            progress
                .send(DownloadProgressItem::ProviderFailed {
                    id: provider,
                    request: request.clone(),
                })
                .await?;
            continue;
        };
        match remote
            .execute_get_sink(
                conn.clone(),
                local.missing(),
                // Progress values are offset by the bytes that were already present locally.
                (&mut progress).with_map(move |x| DownloadProgressItem::Progress(x + local_bytes)),
            )
            .await
        {
            Ok(_stats) => {
                progress
                    .send(DownloadProgressItem::PartComplete {
                        request: request.clone(),
                    })
                    .await?;
                return Ok(());
            }
            Err(_cause) => {
                progress
                    .send(DownloadProgressItem::ProviderFailed {
                        id: provider,
                        request: request.clone(),
                    })
                    .await?;
                continue;
            }
        }
    }
    Err(anyerr!("Unable to download {}", request.hash))
}
487
488pub trait ContentDiscovery: Debug + Send + Sync + 'static {
490 fn find_providers(&self, hash: HashAndFormat) -> n0_future::stream::Boxed<EndpointId>;
491}
492
493impl<C, I> ContentDiscovery for C
494where
495 C: Debug + Clone + IntoIterator<Item = I> + Send + Sync + 'static,
496 C::IntoIter: Send + Sync + 'static,
497 I: Into<EndpointId> + Send + Sync + 'static,
498{
499 fn find_providers(&self, _: HashAndFormat) -> n0_future::stream::Boxed<EndpointId> {
500 let providers = self.clone();
501 n0_future::stream::iter(providers.into_iter().map(Into::into)).boxed()
502 }
503}
504
505#[derive(derive_more::Debug)]
506pub struct Shuffled {
507 nodes: Vec<EndpointId>,
508}
509
510impl Shuffled {
511 pub fn new(nodes: Vec<EndpointId>) -> Self {
512 Self { nodes }
513 }
514}
515
516impl ContentDiscovery for Shuffled {
517 fn find_providers(&self, _: HashAndFormat) -> n0_future::stream::Boxed<EndpointId> {
518 let mut nodes = self.nodes.clone();
519 nodes.shuffle(&mut rand::rng());
520 n0_future::stream::iter(nodes).boxed()
521 }
522}
523
// Integration tests. Each test sets up three fs-backed nodes ("a", "b", "c") with
// `node_test_setup_fs` and downloads into the third node's store using the first two
// nodes as providers.
#[cfg(test)]
#[cfg(feature = "fs-store")]
mod tests {
    use std::ops::Deref;

    use bao_tree::ChunkRanges;
    use n0_future::StreamExt;
    use testresult::TestResult;

    use crate::{
        api::{
            blobs::AddBytesOptions,
            downloader::{DownloadOptions, Downloader, Shuffled, SplitStrategy},
        },
        hashseq::HashSeq,
        protocol::{GetManyRequest, GetRequest},
        tests::node_test_setup_fs,
    };

    /// Downloads two small blobs held by different providers with one `GetManyRequest`
    /// and checks that both end up in the local store. Currently ignored ("todo").
    #[tokio::test]
    #[ignore = "todo"]
    async fn downloader_get_many_smoke() -> TestResult<()> {
        let testdir = tempfile::tempdir()?;
        let (r1, store1, _, _) = node_test_setup_fs(testdir.path().join("a")).await?;
        let (r2, store2, _, _) = node_test_setup_fs(testdir.path().join("b")).await?;
        let (r3, store3, _, sp3) = node_test_setup_fs(testdir.path().join("c")).await?;
        let tt1 = store1.add_slice("hello world").await?;
        let tt2 = store2.add_slice("hello world 2").await?;
        let node1_addr = r1.endpoint().addr();
        let node1_id = node1_addr.id;
        let node2_addr = r2.endpoint().addr();
        let node2_id = node2_addr.id;
        let swarm = Downloader::new(&store3, r3.endpoint());
        // Make the providers' addresses known to the downloading node.
        sp3.add_endpoint_info(node1_addr.clone());
        sp3.add_endpoint_info(node2_addr.clone());
        let request = GetManyRequest::builder()
            .hash(tt1.hash, ChunkRanges::all())
            .hash(tt2.hash, ChunkRanges::all())
            .build();
        let mut progress = swarm
            .download(request, Shuffled::new(vec![node1_id, node2_id]))
            .stream()
            .await?;
        while progress.next().await.is_some() {}
        assert_eq!(store3.get_bytes(tt1.hash).await?.deref(), b"hello world");
        assert_eq!(store3.get_bytes(tt2.hash).await?.deref(), b"hello world 2");
        Ok(())
    }

    /// Downloads a hash-sequence root plus both children with `SplitStrategy::Split`.
    /// Node "a" holds the root and the first child; node "b" holds the second child.
    /// This only drains the progress stream; it does not assert the downloaded contents.
    #[tokio::test]
    async fn downloader_get_smoke() -> TestResult<()> {
        let testdir = tempfile::tempdir()?;
        let (r1, store1, _, _) = node_test_setup_fs(testdir.path().join("a")).await?;
        let (r2, store2, _, _) = node_test_setup_fs(testdir.path().join("b")).await?;
        let (r3, store3, _, sp3) = node_test_setup_fs(testdir.path().join("c")).await?;
        let tt1 = store1.add_slice(vec![1; 10000000]).await?;
        let tt2 = store2.add_slice(vec![2; 10000000]).await?;
        let hs = [tt1.hash, tt2.hash].into_iter().collect::<HashSeq>();
        let root = store1
            .add_bytes_with_opts(AddBytesOptions {
                data: hs.clone().into(),
                format: crate::BlobFormat::HashSeq,
            })
            .await?;
        let node1_addr = r1.endpoint().addr();
        let node1_id = node1_addr.id;
        let node2_addr = r2.endpoint().addr();
        let node2_id = node2_addr.id;
        let swarm = Downloader::new(&store3, r3.endpoint());
        sp3.add_endpoint_info(node1_addr.clone());
        sp3.add_endpoint_info(node2_addr.clone());
        // Root plus the first two children, all ranges.
        let request = GetRequest::builder()
            .root(ChunkRanges::all())
            .next(ChunkRanges::all())
            .next(ChunkRanges::all())
            .build(root.hash);
        if true {
            let mut progress = swarm
                .download_with_opts(DownloadOptions::new(
                    request,
                    [node1_id, node2_id],
                    SplitStrategy::Split,
                ))
                .stream()
                .await?;
            while progress.next().await.is_some() {}
        }
        // This branch never runs. It fetches root and children manually over a single
        // connection to node "a", apparently kept around for comparison/debugging.
        if false {
            let conn = r3.endpoint().connect(node1_addr, crate::ALPN).await?;
            let remote = store3.remote();
            let _rh = remote
                .execute_get(
                    conn.clone(),
                    GetRequest::builder()
                        .root(ChunkRanges::all())
                        .build(root.hash),
                )
                .await?;
            let h1 = remote.execute_get(
                conn.clone(),
                GetRequest::builder()
                    .child(0, ChunkRanges::all())
                    .build(root.hash),
            );
            let h2 = remote.execute_get(
                conn.clone(),
                GetRequest::builder()
                    .child(1, ChunkRanges::all())
                    .build(root.hash),
            );
            h1.await?;
            h2.await?;
        }
        Ok(())
    }

    /// Same setup as `downloader_get_smoke`, but requests the whole hash sequence via
    /// `GetRequest::all`. This only drains the progress stream; there are no content
    /// assertions.
    #[tokio::test]
    async fn downloader_get_all() -> TestResult<()> {
        let testdir = tempfile::tempdir()?;
        let (r1, store1, _, _) = node_test_setup_fs(testdir.path().join("a")).await?;
        let (r2, store2, _, _) = node_test_setup_fs(testdir.path().join("b")).await?;
        let (r3, store3, _, sp3) = node_test_setup_fs(testdir.path().join("c")).await?;
        let tt1 = store1.add_slice(vec![1; 10000000]).await?;
        let tt2 = store2.add_slice(vec![2; 10000000]).await?;
        let hs = [tt1.hash, tt2.hash].into_iter().collect::<HashSeq>();
        let root = store1
            .add_bytes_with_opts(AddBytesOptions {
                data: hs.clone().into(),
                format: crate::BlobFormat::HashSeq,
            })
            .await?;
        let node1_addr = r1.endpoint().addr();
        let node1_id = node1_addr.id;
        let node2_addr = r2.endpoint().addr();
        let node2_id = node2_addr.id;
        let swarm = Downloader::new(&store3, r3.endpoint());
        sp3.add_endpoint_info(node1_addr.clone());
        sp3.add_endpoint_info(node2_addr.clone());
        let request = GetRequest::all(root.hash);
        let mut progress = swarm
            .download_with_opts(DownloadOptions::new(
                request,
                [node1_id, node2_id],
                SplitStrategy::Split,
            ))
            .stream()
            .await?;
        while progress.next().await.is_some() {}
        Ok(())
    }
}