1use std::{fmt::Debug, future::Future, io, time::Duration};
7
8use bao_tree::ChunkRanges;
9use iroh::endpoint::{self, ConnectionError, VarInt};
10use iroh_io::{AsyncStreamReader, AsyncStreamWriter};
11use n0_error::{e, stack_error, Result};
12use n0_future::{time::Instant, StreamExt};
13use serde::{Deserialize, Serialize};
14use tokio::select;
15use tracing::{debug, debug_span, Instrument};
16
17use crate::{
18 api::{
19 blobs::{Bitfield, WriteProgress},
20 ExportBaoError, ExportBaoResult, RequestError, Store,
21 },
22 hashseq::HashSeq,
23 protocol::{
24 GetManyRequest, GetRequest, ObserveItem, ObserveRequest, PushRequest, Request, ERR_INTERNAL,
25 },
26 provider::events::{
27 ClientConnected, ClientResult, ConnectionClosed, HasErrorCode, ProgressError,
28 RequestTracker,
29 },
30 util::{RecvStream, RecvStreamExt, SendStream, SendStreamExt},
31 Hash,
32};
33pub mod events;
34use events::EventSender;
35
36type DefaultReader = iroh::endpoint::RecvStream;
37type DefaultWriter = iroh::endpoint::SendStream;
38
39#[derive(Debug, Serialize, Deserialize)]
41pub struct TransferStats {
42 pub payload_bytes_sent: u64,
44 pub other_bytes_sent: u64,
48 pub other_bytes_read: u64,
53 pub duration: Duration,
55}
56
57#[derive(Debug)]
59pub struct StreamPair<R: RecvStream = DefaultReader, W: SendStream = DefaultWriter> {
60 t0: Instant,
61 connection_id: u64,
62 reader: R,
63 writer: W,
64 other_bytes_read: u64,
65 events: EventSender,
66}
67
68impl StreamPair {
69 pub async fn accept(
70 conn: &endpoint::Connection,
71 events: EventSender,
72 ) -> Result<Self, ConnectionError> {
73 let (writer, reader) = conn.accept_bi().await?;
74 Ok(Self::new(conn.stable_id() as u64, reader, writer, events))
75 }
76}
77
78impl<R: RecvStream, W: SendStream> StreamPair<R, W> {
79 pub fn stream_id(&self) -> u64 {
80 self.reader.id()
81 }
82
83 pub fn new(connection_id: u64, reader: R, writer: W, events: EventSender) -> Self {
84 Self {
85 t0: Instant::now(),
86 connection_id,
87 reader,
88 writer,
89 other_bytes_read: 0,
90 events,
91 }
92 }
93
94 pub async fn read_request(&mut self) -> Result<Request> {
103 let (res, size) = Request::read_async(&mut self.reader).await?;
104 self.other_bytes_read += size as u64;
105 Ok(res)
106 }
107
108 pub async fn into_writer(
110 mut self,
111 tracker: RequestTracker,
112 ) -> Result<ProgressWriter<W>, io::Error> {
113 self.reader.expect_eof().await?;
114 drop(self.reader);
115 Ok(ProgressWriter::new(
116 self.writer,
117 WriterContext {
118 t0: self.t0,
119 other_bytes_read: self.other_bytes_read,
120 payload_bytes_written: 0,
121 other_bytes_written: 0,
122 tracker,
123 },
124 ))
125 }
126
127 pub async fn into_reader(
128 mut self,
129 tracker: RequestTracker,
130 ) -> Result<ProgressReader<R>, io::Error> {
131 self.writer.sync().await?;
132 drop(self.writer);
133 Ok(ProgressReader {
134 inner: self.reader,
135 context: ReaderContext {
136 t0: self.t0,
137 other_bytes_read: self.other_bytes_read,
138 tracker,
139 },
140 })
141 }
142
143 pub async fn get_request(
144 &self,
145 f: impl FnOnce() -> GetRequest,
146 ) -> Result<RequestTracker, ProgressError> {
147 self.events
148 .request(f, self.connection_id, self.reader.id())
149 .await
150 }
151
152 pub async fn get_many_request(
153 &self,
154 f: impl FnOnce() -> GetManyRequest,
155 ) -> Result<RequestTracker, ProgressError> {
156 self.events
157 .request(f, self.connection_id, self.reader.id())
158 .await
159 }
160
161 pub async fn push_request(
162 &self,
163 f: impl FnOnce() -> PushRequest,
164 ) -> Result<RequestTracker, ProgressError> {
165 self.events
166 .request(f, self.connection_id, self.reader.id())
167 .await
168 }
169
170 pub async fn observe_request(
171 &self,
172 f: impl FnOnce() -> ObserveRequest,
173 ) -> Result<RequestTracker, ProgressError> {
174 self.events
175 .request(f, self.connection_id, self.reader.id())
176 .await
177 }
178
179 pub fn stats(&self) -> TransferStats {
180 TransferStats {
181 payload_bytes_sent: 0,
182 other_bytes_sent: 0,
183 other_bytes_read: self.other_bytes_read,
184 duration: self.t0.elapsed(),
185 }
186 }
187}
188
/// Progress bookkeeping for a transfer where the provider reads data from the
/// peer (see `StreamPair::into_reader`).
#[derive(Debug)]
struct ReaderContext {
    /// Moment the originating stream pair was created.
    t0: Instant,
    /// Non-payload bytes read so far (the encoded request).
    other_bytes_read: u64,
    /// Tracker that receives completion/abort notifications for the request.
    tracker: RequestTracker,
}
198
199impl ReaderContext {
200 fn stats(&self) -> TransferStats {
201 TransferStats {
202 payload_bytes_sent: 0,
203 other_bytes_sent: 0,
204 other_bytes_read: self.other_bytes_read,
205 duration: self.t0.elapsed(),
206 }
207 }
208}
209
/// Progress bookkeeping for a transfer where the provider writes data to the
/// peer (see `StreamPair::into_writer`).
#[derive(Debug)]
pub(crate) struct WriterContext {
    /// Moment the originating stream pair was created.
    t0: Instant,
    /// Non-payload bytes read before switching to writing (the encoded request).
    other_bytes_read: u64,
    /// Blob payload bytes written, accumulated in `notify_payload_write`.
    payload_bytes_written: u64,
    /// Non-payload bytes written, accumulated in `log_other_write`.
    other_bytes_written: u64,
    /// Tracker that receives start/progress/completion/abort notifications.
    tracker: RequestTracker,
}
223
224impl WriterContext {
225 fn stats(&self) -> TransferStats {
226 TransferStats {
227 payload_bytes_sent: self.payload_bytes_written,
228 other_bytes_sent: self.other_bytes_written,
229 other_bytes_read: self.other_bytes_read,
230 duration: self.t0.elapsed(),
231 }
232 }
233}
234
235impl WriteProgress for WriterContext {
236 async fn notify_payload_write(&mut self, _index: u64, offset: u64, len: usize) -> ClientResult {
237 let len = len as u64;
238 let end_offset = offset + len;
239 self.payload_bytes_written += len;
240 self.tracker.transfer_progress(len, end_offset).await
241 }
242
243 fn log_other_write(&mut self, len: usize) {
244 self.other_bytes_written += len as u64;
245 }
246
247 async fn send_transfer_started(&mut self, index: u64, hash: &Hash, size: u64) {
248 self.tracker.transfer_started(index, hash, size).await.ok();
249 }
250}
251
/// Send half of a stream pair plus the context used to report progress for
/// a request the provider answers by writing.
#[derive(Debug)]
pub struct ProgressWriter<W: SendStream = DefaultWriter> {
    /// The underlying send stream.
    pub inner: W,
    /// Counters and request tracker for this transfer.
    pub(crate) context: WriterContext,
}
259
260impl<W: SendStream> ProgressWriter<W> {
261 fn new(inner: W, context: WriterContext) -> Self {
262 Self { inner, context }
263 }
264
265 async fn transfer_aborted(&self) {
266 self.context
267 .tracker
268 .transfer_aborted(|| Box::new(self.context.stats()))
269 .await
270 .ok();
271 }
272
273 async fn transfer_completed(&self) {
274 self.context
275 .tracker
276 .transfer_completed(|| Box::new(self.context.stats()))
277 .await
278 .ok();
279 }
280}
281
282pub async fn handle_connection(
284 connection: endpoint::Connection,
285 store: Store,
286 progress: EventSender,
287) {
288 let connection_id = connection.stable_id() as u64;
289 let span = debug_span!("connection", connection_id);
290 async move {
291 if let Err(cause) = progress
292 .client_connected(|| ClientConnected {
293 connection_id,
294 endpoint_id: Some(connection.remote_id()),
295 })
296 .await
297 {
298 connection.close(cause.code(), cause.reason());
299 debug!("closing connection: {cause}");
300 return;
301 }
302 while let Ok(pair) = StreamPair::accept(&connection, progress.clone()).await {
303 let span = debug_span!("stream", stream_id = %pair.stream_id());
304 let store = store.clone();
305 n0_future::task::spawn(handle_stream(pair, store).instrument(span));
306 }
307 progress
308 .connection_closed(|| ConnectionClosed { connection_id })
309 .await
310 .ok();
311 }
312 .instrument(span)
313 .await
314}
315
/// Hooks for tearing down stream halves with an error code.
///
/// NOTE(review): nothing in this file implements or calls this trait; the
/// handlers here call `RecvStream::stop` / `SendStream::reset` directly.
/// Confirm whether it is still used elsewhere.
pub trait ErrorHandler {
    /// Writer type this handler operates on.
    type W: AsyncStreamWriter;
    /// Reader type this handler operates on.
    type R: AsyncStreamReader;
    /// Intended to stop `reader` with the application error `code`
    /// (exact semantics are up to the implementor).
    fn stop(reader: &mut Self::R, code: VarInt) -> impl Future<Output = ()>;
    /// Intended to reset `writer` with the application error `code`
    /// (exact semantics are up to the implementor).
    fn reset(writer: &mut Self::W, code: VarInt) -> impl Future<Output = ()>;
}
323
324async fn handle_read_request_result<R: RecvStream, W: SendStream, T, E: HasErrorCode>(
325 pair: &mut StreamPair<R, W>,
326 r: Result<T, E>,
327) -> Result<T, E> {
328 match r {
329 Ok(x) => Ok(x),
330 Err(e) => {
331 pair.writer.reset(e.code()).ok();
332 Err(e)
333 }
334 }
335}
336async fn handle_write_result<W: SendStream, T, E: HasErrorCode>(
337 writer: &mut ProgressWriter<W>,
338 r: Result<T, E>,
339) -> Result<T, E> {
340 match r {
341 Ok(x) => {
342 writer.transfer_completed().await;
343 Ok(x)
344 }
345 Err(e) => {
346 writer.inner.reset(e.code()).ok();
347 writer.transfer_aborted().await;
348 Err(e)
349 }
350 }
351}
352async fn handle_read_result<R: RecvStream, T, E: HasErrorCode>(
353 reader: &mut ProgressReader<R>,
354 r: Result<T, E>,
355) -> Result<T, E> {
356 match r {
357 Ok(x) => {
358 reader.transfer_completed().await;
359 Ok(x)
360 }
361 Err(e) => {
362 reader.inner.stop(e.code()).ok();
363 reader.transfer_aborted().await;
364 Err(e)
365 }
366 }
367}
368
369pub async fn handle_stream<R: RecvStream, W: SendStream>(
370 mut pair: StreamPair<R, W>,
371 store: Store,
372) -> n0_error::Result<()> {
373 let request = pair.read_request().await?;
374 match request {
375 Request::Get(request) => handle_get(pair, store, request).await?,
376 Request::GetMany(request) => handle_get_many(pair, store, request).await?,
377 Request::Observe(request) => handle_observe(pair, store, request).await?,
378 Request::Push(request) => handle_push(pair, store, request).await?,
379 _ => {}
380 }
381 Ok(())
382}
383
/// Errors that can occur while serving a [`GetRequest`].
#[stack_error(derive, add_meta, from_sources)]
pub enum HandleGetError {
    /// Exporting blob data failed, or syncing the writer at the end failed.
    #[error(transparent)]
    ExportBao {
        #[error(std_err)]
        source: ExportBaoError,
    },
    /// The root blob could not be parsed as a hash sequence.
    #[error("Invalid hash sequence")]
    InvalidHashSeq {},
    /// A child offset did not fit into `usize`.
    #[error("Invalid offset")]
    InvalidOffset {},
}
396
397impl HasErrorCode for HandleGetError {
398 fn code(&self) -> VarInt {
399 match self {
400 HandleGetError::ExportBao {
401 source: ExportBaoError::ClientError { source, .. },
402 ..
403 } => source.code(),
404 HandleGetError::InvalidHashSeq { .. } => ERR_INTERNAL,
405 HandleGetError::InvalidOffset { .. } => ERR_INTERNAL,
406 _ => ERR_INTERNAL,
407 }
408 }
409}
410
411async fn handle_get_impl<W: SendStream>(
415 store: Store,
416 request: GetRequest,
417 writer: &mut ProgressWriter<W>,
418) -> Result<(), HandleGetError> {
419 let hash = request.hash;
420 debug!(%hash, "get received request");
421 let mut hash_seq = None;
422 for (offset, ranges) in request.ranges.iter_non_empty_infinite() {
423 if offset == 0 {
424 send_blob(&store, offset, hash, ranges.clone(), writer).await?;
425 } else {
426 let hash_seq = match &hash_seq {
432 Some(b) => b,
433 None => {
434 let bytes = store.get_bytes(hash).await?;
435 let hs =
436 HashSeq::try_from(bytes).map_err(|_| e!(HandleGetError::InvalidHashSeq))?;
437 hash_seq = Some(hs);
438 hash_seq.as_ref().unwrap()
439 }
440 };
441 let o = usize::try_from(offset - 1).map_err(|_| e!(HandleGetError::InvalidOffset))?;
442 let Some(hash) = hash_seq.get(o) else {
443 break;
444 };
445 send_blob(&store, offset, hash, ranges.clone(), writer).await?;
446 }
447 }
448 writer
449 .inner
450 .sync()
451 .await
452 .map_err(|e| e!(HandleGetError::ExportBao, e.into()))?;
453
454 Ok(())
455}
456
457pub async fn handle_get<R: RecvStream, W: SendStream>(
458 mut pair: StreamPair<R, W>,
459 store: Store,
460 request: GetRequest,
461) -> n0_error::Result<()> {
462 let res = pair.get_request(|| request.clone()).await;
463 let tracker = handle_read_request_result(&mut pair, res).await?;
464 let mut writer = pair.into_writer(tracker).await?;
465 let res = handle_get_impl(store, request, &mut writer).await;
466 handle_write_result(&mut writer, res).await?;
467 Ok(())
468}
469
/// Errors that can occur while serving a [`GetManyRequest`].
#[stack_error(derive, add_meta, from_sources)]
pub enum HandleGetManyError {
    /// Exporting or writing blob data to the peer failed.
    #[error(transparent)]
    ExportBao { source: ExportBaoError },
}
475
476impl HasErrorCode for HandleGetManyError {
477 fn code(&self) -> VarInt {
478 match self {
479 Self::ExportBao {
480 source: ExportBaoError::ClientError { source, .. },
481 ..
482 } => source.code(),
483 _ => ERR_INTERNAL,
484 }
485 }
486}
487
488async fn handle_get_many_impl<W: SendStream>(
492 store: Store,
493 request: GetManyRequest,
494 writer: &mut ProgressWriter<W>,
495) -> Result<(), HandleGetManyError> {
496 debug!("get_many received request");
497 let request_ranges = request.ranges.iter_infinite();
498 for (child, (hash, ranges)) in request.hashes.iter().zip(request_ranges).enumerate() {
499 if !ranges.is_empty() {
500 send_blob(&store, child as u64, *hash, ranges.clone(), writer).await?;
501 }
502 }
503 Ok(())
504}
505
506pub async fn handle_get_many<R: RecvStream, W: SendStream>(
507 mut pair: StreamPair<R, W>,
508 store: Store,
509 request: GetManyRequest,
510) -> n0_error::Result<()> {
511 let res = pair.get_many_request(|| request.clone()).await;
512 let tracker = handle_read_request_result(&mut pair, res).await?;
513 let mut writer = pair.into_writer(tracker).await?;
514 let res = handle_get_many_impl(store, request, &mut writer).await;
515 handle_write_result(&mut writer, res).await?;
516 Ok(())
517}
518
/// Errors that can occur while serving a [`PushRequest`].
#[stack_error(derive, add_meta, from_sources)]
pub enum HandlePushError {
    /// An [`ExportBaoError`] propagated while handling the push
    /// (presumably from reading the root blob back from the store — confirm).
    #[error(transparent)]
    ExportBao { source: ExportBaoError },

    /// The pushed root blob could not be parsed as a hash sequence.
    #[error("Invalid hash sequence")]
    InvalidHashSeq {},

    /// A [`RequestError`] from a store request
    /// (presumably the bao import — confirm).
    #[error(transparent)]
    Request { source: RequestError },
}
530
531impl HasErrorCode for HandlePushError {
532 fn code(&self) -> VarInt {
533 match self {
534 Self::ExportBao {
535 source: ExportBaoError::ClientError { source, .. },
536 ..
537 } => source.code(),
538 _ => ERR_INTERNAL,
539 }
540 }
541}
542
543async fn handle_push_impl<R: RecvStream>(
547 store: Store,
548 request: PushRequest,
549 reader: &mut ProgressReader<R>,
550) -> Result<(), HandlePushError> {
551 let hash = request.hash;
552 debug!(%hash, "push received request");
553 let mut request_ranges = request.ranges.iter_infinite();
554 let root_ranges = request_ranges.next().expect("infinite iterator");
555 if !root_ranges.is_empty() {
556 store
558 .import_bao_reader(hash, root_ranges.clone(), &mut reader.inner)
559 .await?;
560 }
561 if request.ranges.is_blob() {
562 debug!("push request complete");
563 return Ok(());
564 }
565 let hash_seq = store.get_bytes(hash).await?;
567 let hash_seq = HashSeq::try_from(hash_seq).map_err(|_| e!(HandlePushError::InvalidHashSeq))?;
568 for (child_hash, child_ranges) in hash_seq.into_iter().zip(request_ranges) {
569 if child_ranges.is_empty() {
570 continue;
571 }
572 store
573 .import_bao_reader(child_hash, child_ranges.clone(), &mut reader.inner)
574 .await?;
575 }
576 Ok(())
577}
578
579pub async fn handle_push<R: RecvStream, W: SendStream>(
580 mut pair: StreamPair<R, W>,
581 store: Store,
582 request: PushRequest,
583) -> n0_error::Result<()> {
584 let res = pair.push_request(|| request.clone()).await;
585 let tracker = handle_read_request_result(&mut pair, res).await?;
586 let mut reader = pair.into_reader(tracker).await?;
587 let res = handle_push_impl(store, request, &mut reader).await;
588 handle_read_result(&mut reader, res).await?;
589 Ok(())
590}
591
592pub(crate) async fn send_blob<W: SendStream>(
594 store: &Store,
595 index: u64,
596 hash: Hash,
597 ranges: ChunkRanges,
598 writer: &mut ProgressWriter<W>,
599) -> ExportBaoResult<()> {
600 store
601 .export_bao(hash, ranges)
602 .write_with_progress(&mut writer.inner, &mut writer.context, &hash, index)
603 .await
604}
605
/// Errors that can occur while serving an [`ObserveRequest`].
#[stack_error(derive, add_meta, std_sources, from_sources)]
pub enum HandleObserveError {
    /// The store's observe stream could not be opened or ended unexpectedly.
    #[error("observe stream closed")]
    ObserveStreamClosed {},

    /// Writing an observe item to the peer failed.
    #[error(transparent)]
    RemoteClosed { source: io::Error },
}
614
impl HasErrorCode for HandleObserveError {
    /// Every observe failure maps to [`ERR_INTERNAL`].
    fn code(&self) -> VarInt {
        ERR_INTERNAL
    }
}
620
621async fn handle_observe_impl<W: SendStream>(
625 store: Store,
626 request: ObserveRequest,
627 writer: &mut ProgressWriter<W>,
628) -> std::result::Result<(), HandleObserveError> {
629 let mut stream = store
630 .observe(request.hash)
631 .stream()
632 .await
633 .map_err(|_| e!(HandleObserveError::ObserveStreamClosed))?;
634 let mut old = stream
635 .next()
636 .await
637 .ok_or_else(|| e!(HandleObserveError::ObserveStreamClosed))?;
638 send_observe_item(writer, &old).await?;
640 loop {
642 select! {
643 new = stream.next() => {
644 let new = new.ok_or_else(|| e!(HandleObserveError::ObserveStreamClosed))?;
645 let diff = old.diff(&new);
646 if diff.is_empty() {
647 continue;
648 }
649 send_observe_item(writer, &diff).await?;
650 old = new;
651 }
652 _ = writer.inner.stopped() => {
653 debug!("observer closed");
654 break;
655 }
656 }
657 }
658 Ok(())
659}
660
661async fn send_observe_item<W: SendStream>(
662 writer: &mut ProgressWriter<W>,
663 item: &Bitfield,
664) -> io::Result<()> {
665 let item = ObserveItem::from(item);
666 let len = writer.inner.write_length_prefixed(item).await?;
667 writer.context.log_other_write(len);
668 Ok(())
669}
670
671pub async fn handle_observe<R: RecvStream, W: SendStream>(
672 mut pair: StreamPair<R, W>,
673 store: Store,
674 request: ObserveRequest,
675) -> n0_error::Result<()> {
676 let res = pair.observe_request(|| request.clone()).await;
677 let tracker = handle_read_request_result(&mut pair, res).await?;
678 let mut writer = pair.into_writer(tracker).await?;
679 let res = handle_observe_impl(store, request, &mut writer).await;
680 handle_write_result(&mut writer, res).await?;
681 Ok(())
682}
683
684pub struct ProgressReader<R: RecvStream = DefaultReader> {
685 inner: R,
686 context: ReaderContext,
687}
688
689impl<R: RecvStream> ProgressReader<R> {
690 async fn transfer_aborted(&self) {
691 self.context
692 .tracker
693 .transfer_aborted(|| Box::new(self.context.stats()))
694 .await
695 .ok();
696 }
697
698 async fn transfer_completed(&self) {
699 self.context
700 .tracker
701 .transfer_completed(|| Box::new(self.context.stats()))
702 .await
703 .ok();
704 }
705}