1use std::{fmt::Debug, io, ops::Deref};
2
3use iroh::endpoint::VarInt;
4use irpc::{
5 channel::{mpsc, none::NoSender, oneshot},
6 rpc_requests, Channels, WithChannels,
7};
8use n0_error::{e, stack_error};
9use serde::{Deserialize, Serialize};
10
11use crate::{
12 protocol::{
13 GetManyRequest, GetRequest, ObserveRequest, PushRequest, ERR_INTERNAL, ERR_LIMIT,
14 ERR_PERMISSION,
15 },
16 provider::{events::irpc_ext::IrpcClientExt, TransferStats},
17 Hash,
18};
19
/// Controls how client connections are reported to the event handler.
#[repr(u8)]
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum ConnectMode {
    /// Connections are not reported.
    #[default]
    None = 0,
    /// A `ClientConnected` notification is sent; the handler cannot reject it.
    Notify = 1,
    /// A `ClientConnected` request is sent; a rejection from the handler is
    /// returned as an error by `EventSender::client_connected`.
    Intercept = 2,
}
32
/// Controls how observe requests are reported to the event handler.
///
/// NOTE(review): the variant names mirror [`ConnectMode`]; the exact semantics
/// depend on where `EventMask::observe` is consumed — confirm there.
#[repr(u8)]
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum ObserveMode {
    /// Observe requests are not reported.
    #[default]
    None = 0,
    /// Observe requests are reported without the option to reject them.
    Notify = 1,
    /// Observe requests are reported and may be rejected by the handler.
    Intercept = 2,
}
45
/// Controls how an incoming request of one kind is reported to the event handler.
#[repr(u8)]
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum RequestMode {
    /// No event is sent for the request.
    #[default]
    None = 0,
    /// The handler is notified and cannot reject the request; per-transfer
    /// updates are not forwarded.
    Notify = 1,
    /// The handler must approve the request; per-transfer updates are not forwarded.
    Intercept = 2,
    /// Like `Notify`, but per-transfer updates are forwarded to the handler.
    NotifyLog = 3,
    /// Like `Intercept`, but per-transfer updates are forwarded to the handler.
    InterceptLog = 4,
    /// The request is refused with a permission error, even when no event
    /// handler is attached.
    Disabled = 5,
}
68
/// Controls whether transfers ask the event handler for permission to continue.
#[repr(u8)]
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum ThrottleMode {
    /// Transfers are never throttled.
    #[default]
    None = 0,
    /// Every progress report sends a `Throttle` request; a rejection aborts
    /// the transfer with an error.
    Intercept = 1,
}
79
80#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
81pub enum AbortReason {
82 RateLimited,
84 Permission,
86}
87
/// Error raised while reporting provider events or when the event handler
/// stops a request or transfer.
#[stack_error(derive, add_meta, from_sources)]
pub enum ProgressError {
    /// A limit was hit; produced from `AbortReason::RateLimited`.
    #[error("limit")]
    Limit {},
    /// The operation is not permitted; produced from `AbortReason::Permission`
    /// and for request types configured as `RequestMode::Disabled`.
    #[error("permission")]
    Permission {},
    /// Communication with the event handler failed (irpc or channel error).
    #[error(transparent)]
    Internal { source: irpc::Error },
}
98
99impl From<ProgressError> for io::Error {
100 fn from(value: ProgressError) -> Self {
101 match value {
102 ProgressError::Limit { .. } => io::ErrorKind::QuotaExceeded.into(),
103 ProgressError::Permission { .. } => io::ErrorKind::PermissionDenied.into(),
104 ProgressError::Internal { source, .. } => source.into(),
105 }
106 }
107}
108
/// Errors that can be expressed as a protocol-level error code.
pub trait HasErrorCode {
    /// Returns the [`VarInt`] error code for this error.
    ///
    /// NOTE(review): presumably used when closing a QUIC stream or connection
    /// because of this error — confirm at the call sites.
    fn code(&self) -> VarInt;
}
112
113impl HasErrorCode for ProgressError {
114 fn code(&self) -> VarInt {
115 match self {
116 ProgressError::Limit { .. } => ERR_LIMIT,
117 ProgressError::Permission { .. } => ERR_PERMISSION,
118 ProgressError::Internal { .. } => ERR_INTERNAL,
119 }
120 }
121}
122
123impl ProgressError {
124 pub fn reason(&self) -> &'static [u8] {
125 match self {
126 ProgressError::Limit { .. } => b"limit",
127 ProgressError::Permission { .. } => b"permission",
128 ProgressError::Internal { .. } => b"internal",
129 }
130 }
131}
132
133impl From<AbortReason> for ProgressError {
134 fn from(value: AbortReason) -> Self {
135 match value {
136 AbortReason::RateLimited => n0_error::e!(ProgressError::Limit),
137 AbortReason::Permission => n0_error::e!(ProgressError::Permission),
138 }
139 }
140}
141
142impl From<irpc::channel::mpsc::RecvError> for ProgressError {
143 fn from(value: irpc::channel::mpsc::RecvError) -> Self {
144 n0_error::e!(ProgressError::Internal, value.into())
145 }
146}
147
148impl From<irpc::channel::oneshot::RecvError> for ProgressError {
149 fn from(value: irpc::channel::oneshot::RecvError) -> Self {
150 n0_error::e!(ProgressError::Internal, value.into())
151 }
152}
153
154impl From<irpc::channel::SendError> for ProgressError {
155 fn from(value: irpc::channel::SendError) -> Self {
156 n0_error::e!(ProgressError::Internal, value.into())
157 }
158}
159
/// The handler's verdict on an intercepted event; `Err` carries the rejection reason.
pub type EventResult = Result<(), AbortReason>;
/// Outcome of reporting an event from the provider side.
pub type ClientResult = Result<(), ProgressError>;
162
163#[derive(Debug, Clone, Copy, PartialEq, Eq)]
168pub struct EventMask {
169 pub connected: ConnectMode,
171 pub get: RequestMode,
173 pub get_many: RequestMode,
175 pub push: RequestMode,
177 pub observe: ObserveMode,
179 pub throttle: ThrottleMode,
181}
182
183impl Default for EventMask {
184 fn default() -> Self {
185 Self::DEFAULT
186 }
187}
188
189impl EventMask {
190 pub const DEFAULT: Self = Self {
192 connected: ConnectMode::None,
193 get: RequestMode::None,
194 get_many: RequestMode::None,
195 push: RequestMode::Disabled,
196 throttle: ThrottleMode::None,
197 observe: ObserveMode::None,
198 };
199
200 pub const ALL_READONLY: Self = Self {
206 connected: ConnectMode::Intercept,
207 get: RequestMode::InterceptLog,
208 get_many: RequestMode::InterceptLog,
209 push: RequestMode::Disabled,
210 throttle: ThrottleMode::Intercept,
211 observe: ObserveMode::Intercept,
212 };
213}
214
215#[derive(Debug, Serialize, Deserialize)]
217pub struct Notify<T>(T);
218
219impl<T> Deref for Notify<T> {
220 type Target = T;
221
222 fn deref(&self) -> &Self::Target {
223 &self.0
224 }
225}
226
227#[derive(Debug, Default, Clone)]
228pub struct EventSender {
229 mask: EventMask,
230 inner: Option<irpc::Client<ProviderProto>>,
231}
232
233#[derive(Debug, Default)]
234enum RequestUpdates {
235 #[default]
237 None,
238 Active(mpsc::Sender<RequestUpdate>),
240 Disabled(#[allow(dead_code)] mpsc::Sender<RequestUpdate>),
243}
244
245#[derive(Debug)]
246pub struct RequestTracker {
247 updates: RequestUpdates,
248 throttle: Option<(irpc::Client<ProviderProto>, u64, u64)>,
249}
250
251impl RequestTracker {
252 fn new(
253 updates: RequestUpdates,
254 throttle: Option<(irpc::Client<ProviderProto>, u64, u64)>,
255 ) -> Self {
256 Self { updates, throttle }
257 }
258
259 pub const NONE: Self = Self {
261 updates: RequestUpdates::None,
262 throttle: None,
263 };
264
265 pub async fn transfer_started(&self, index: u64, hash: &Hash, size: u64) -> irpc::Result<()> {
267 if let RequestUpdates::Active(tx) = &self.updates {
268 tx.send(
269 TransferStarted {
270 index,
271 hash: *hash,
272 size,
273 }
274 .into(),
275 )
276 .await?;
277 }
278 Ok(())
279 }
280
281 pub async fn transfer_progress(&mut self, len: u64, end_offset: u64) -> ClientResult {
283 if let RequestUpdates::Active(tx) = &mut self.updates {
284 tx.try_send(TransferProgress { end_offset }.into()).await?;
285 }
286 if let Some((throttle, connection_id, request_id)) = &self.throttle {
287 throttle
288 .rpc(Throttle {
289 connection_id: *connection_id,
290 request_id: *request_id,
291 size: len,
292 })
293 .await??;
294 }
295 Ok(())
296 }
297
298 pub async fn transfer_completed(&self, f: impl Fn() -> Box<TransferStats>) -> irpc::Result<()> {
300 if let RequestUpdates::Active(tx) = &self.updates {
301 tx.send(TransferCompleted { stats: f() }.into()).await?;
302 }
303 Ok(())
304 }
305
306 pub async fn transfer_aborted(&self, f: impl Fn() -> Box<TransferStats>) -> irpc::Result<()> {
308 if let RequestUpdates::Active(tx) = &self.updates {
309 tx.send(TransferAborted { stats: f() }.into()).await?;
310 }
311 Ok(())
312 }
313}
314
315impl EventSender {
320 pub const DEFAULT: Self = Self {
322 mask: EventMask::DEFAULT,
323 inner: None,
324 };
325
326 pub fn new(client: tokio::sync::mpsc::Sender<ProviderMessage>, mask: EventMask) -> Self {
327 Self {
328 mask,
329 inner: Some(irpc::Client::from(client)),
330 }
331 }
332
333 pub fn channel(
334 capacity: usize,
335 mask: EventMask,
336 ) -> (Self, tokio::sync::mpsc::Receiver<ProviderMessage>) {
337 let (tx, rx) = tokio::sync::mpsc::channel(capacity);
338 (Self::new(tx, mask), rx)
339 }
340
341 pub fn tracing(&self, mask: EventMask) -> Self {
343 use tracing::trace;
344 let (tx, mut rx) = tokio::sync::mpsc::channel(32);
345 n0_future::task::spawn(async move {
346 fn log_request_events(
347 mut rx: irpc::channel::mpsc::Receiver<RequestUpdate>,
348 connection_id: u64,
349 request_id: u64,
350 ) {
351 n0_future::task::spawn(async move {
352 while let Ok(Some(update)) = rx.recv().await {
353 trace!(%connection_id, %request_id, "{update:?}");
354 }
355 });
356 }
357 while let Some(msg) = rx.recv().await {
358 match msg {
359 ProviderMessage::ClientConnected(msg) => {
360 trace!("{:?}", msg.inner);
361 msg.tx.send(Ok(())).await.ok();
362 }
363 ProviderMessage::ClientConnectedNotify(msg) => {
364 trace!("{:?}", msg.inner);
365 }
366 ProviderMessage::ConnectionClosed(msg) => {
367 trace!("{:?}", msg.inner);
368 }
369 ProviderMessage::GetRequestReceived(msg) => {
370 trace!("{:?}", msg.inner);
371 msg.tx.send(Ok(())).await.ok();
372 log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
373 }
374 ProviderMessage::GetRequestReceivedNotify(msg) => {
375 trace!("{:?}", msg.inner);
376 log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
377 }
378 ProviderMessage::GetManyRequestReceived(msg) => {
379 trace!("{:?}", msg.inner);
380 msg.tx.send(Ok(())).await.ok();
381 log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
382 }
383 ProviderMessage::GetManyRequestReceivedNotify(msg) => {
384 trace!("{:?}", msg.inner);
385 log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
386 }
387 ProviderMessage::PushRequestReceived(msg) => {
388 trace!("{:?}", msg.inner);
389 msg.tx.send(Ok(())).await.ok();
390 log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
391 }
392 ProviderMessage::PushRequestReceivedNotify(msg) => {
393 trace!("{:?}", msg.inner);
394 log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
395 }
396 ProviderMessage::ObserveRequestReceived(msg) => {
397 trace!("{:?}", msg.inner);
398 msg.tx.send(Ok(())).await.ok();
399 log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
400 }
401 ProviderMessage::ObserveRequestReceivedNotify(msg) => {
402 trace!("{:?}", msg.inner);
403 log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
404 }
405 ProviderMessage::Throttle(msg) => {
406 trace!("{:?}", msg.inner);
407 msg.tx.send(Ok(())).await.ok();
408 }
409 }
410 }
411 });
412 Self {
413 mask,
414 inner: Some(irpc::Client::from(tx)),
415 }
416 }
417
418 pub async fn client_connected(&self, f: impl Fn() -> ClientConnected) -> ClientResult {
420 if let Some(client) = &self.inner {
421 match self.mask.connected {
422 ConnectMode::None => {}
423 ConnectMode::Notify => client.notify(Notify(f())).await?,
424 ConnectMode::Intercept => client.rpc(f()).await??,
425 }
426 };
427 Ok(())
428 }
429
430 pub async fn connection_closed(&self, f: impl Fn() -> ConnectionClosed) -> ClientResult {
432 if let Some(client) = &self.inner {
433 client.notify(f()).await?;
434 };
435 Ok(())
436 }
437
438 pub(crate) async fn request<Req>(
442 &self,
443 f: impl FnOnce() -> Req,
444 connection_id: u64,
445 request_id: u64,
446 ) -> Result<RequestTracker, ProgressError>
447 where
448 ProviderProto: From<RequestReceived<Req>>,
449 ProviderMessage: From<WithChannels<RequestReceived<Req>, ProviderProto>>,
450 RequestReceived<Req>: Channels<
451 ProviderProto,
452 Tx = oneshot::Sender<EventResult>,
453 Rx = mpsc::Receiver<RequestUpdate>,
454 >,
455 ProviderProto: From<Notify<RequestReceived<Req>>>,
456 ProviderMessage: From<WithChannels<Notify<RequestReceived<Req>>, ProviderProto>>,
457 Notify<RequestReceived<Req>>:
458 Channels<ProviderProto, Tx = NoSender, Rx = mpsc::Receiver<RequestUpdate>>,
459 {
460 let client = self.inner.as_ref();
461 Ok(self.create_tracker((
462 match self.mask.get {
463 RequestMode::None => RequestUpdates::None,
464 RequestMode::Notify if client.is_some() => {
465 let msg = RequestReceived {
466 request: f(),
467 connection_id,
468 request_id,
469 };
470 RequestUpdates::Disabled(
471 client.unwrap().notify_streaming(Notify(msg), 32).await?,
472 )
473 }
474 RequestMode::Intercept if client.is_some() => {
475 let msg = RequestReceived {
476 request: f(),
477 connection_id,
478 request_id,
479 };
480 let (tx, rx) = client.unwrap().client_streaming(msg, 32).await?;
481 rx.await??;
483 RequestUpdates::Disabled(tx)
484 }
485 RequestMode::NotifyLog if client.is_some() => {
486 let msg = RequestReceived {
487 request: f(),
488 connection_id,
489 request_id,
490 };
491 RequestUpdates::Active(client.unwrap().notify_streaming(Notify(msg), 32).await?)
492 }
493 RequestMode::InterceptLog if client.is_some() => {
494 let msg = RequestReceived {
495 request: f(),
496 connection_id,
497 request_id,
498 };
499 let (tx, rx) = client.unwrap().client_streaming(msg, 32).await?;
500 rx.await??;
502 RequestUpdates::Active(tx)
503 }
504 RequestMode::Disabled => {
505 return Err(e!(ProgressError::Permission));
506 }
507 _ => RequestUpdates::None,
508 },
509 connection_id,
510 request_id,
511 )))
512 }
513
514 fn create_tracker(
515 &self,
516 (updates, connection_id, request_id): (RequestUpdates, u64, u64),
517 ) -> RequestTracker {
518 let throttle = match self.mask.throttle {
519 ThrottleMode::None => None,
520 ThrottleMode::Intercept => self
521 .inner
522 .clone()
523 .map(|client| (client, connection_id, request_id)),
524 };
525 RequestTracker::new(updates, throttle)
526 }
527}
528
/// Messages the provider sends to its event handler.
///
/// `#[rpc_requests]` derives the `ProviderMessage` enum from this definition.
/// Variants whose `tx` is a `oneshot::Sender<EventResult>` expect the handler
/// to accept or reject. `NoSender` variants are fire-and-forget. Variants with
/// an `mpsc::Receiver<RequestUpdate>` carry a per-request update stream.
#[rpc_requests(message = ProviderMessage, rpc_feature = "rpc")]
#[derive(Debug, Serialize, Deserialize)]
pub enum ProviderProto {
    #[rpc(tx = oneshot::Sender<EventResult>)]
    /// A client connected; the handler may reject it (`ConnectMode::Intercept`).
    ClientConnected(ClientConnected),

    #[rpc(tx = NoSender)]
    /// A client connected; notification only (`ConnectMode::Notify`).
    ClientConnectedNotify(Notify<ClientConnected>),

    #[rpc(tx = NoSender)]
    /// A connection was closed.
    ConnectionClosed(ConnectionClosed),

    #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = oneshot::Sender<EventResult>)]
    /// A get request arrived; the handler may reject it.
    GetRequestReceived(RequestReceived<GetRequest>),

    #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = NoSender)]
    /// A get request arrived; notification only.
    GetRequestReceivedNotify(Notify<RequestReceived<GetRequest>>),

    #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = oneshot::Sender<EventResult>)]
    /// A get-many request arrived; the handler may reject it.
    GetManyRequestReceived(RequestReceived<GetManyRequest>),

    #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = NoSender)]
    /// A get-many request arrived; notification only.
    GetManyRequestReceivedNotify(Notify<RequestReceived<GetManyRequest>>),

    #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = oneshot::Sender<EventResult>)]
    /// A push request arrived; the handler may reject it.
    PushRequestReceived(RequestReceived<PushRequest>),

    #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = NoSender)]
    /// A push request arrived; notification only.
    PushRequestReceivedNotify(Notify<RequestReceived<PushRequest>>),

    #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = oneshot::Sender<EventResult>)]
    /// An observe request arrived; the handler may reject it.
    ObserveRequestReceived(RequestReceived<ObserveRequest>),

    #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = NoSender)]
    /// An observe request arrived; notification only.
    ObserveRequestReceivedNotify(Notify<RequestReceived<ObserveRequest>>),

    #[rpc(tx = oneshot::Sender<EventResult>)]
    /// Asks whether a transfer may continue after a progress step; a rejection aborts it.
    Throttle(Throttle),
}
580
/// Plain data types carried by `ProviderProto` messages and request update streams.
///
/// These are serialized with serde; field order is part of the wire format,
/// so keep it stable.
mod proto {
    use iroh::EndpointId;
    use serde::{Deserialize, Serialize};

    use crate::{provider::TransferStats, Hash};

    /// A client connected to the provider.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct ClientConnected {
        /// Identifier of the connection.
        pub connection_id: u64,
        /// Endpoint id of the remote peer. NOTE(review): presumably `None`
        /// when it is not known — confirm at the construction site.
        pub endpoint_id: Option<EndpointId>,
    }

    /// A connection was closed.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct ConnectionClosed {
        /// Identifier of the closed connection.
        pub connection_id: u64,
    }

    /// A request of type `R` arrived on a connection.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct RequestReceived<R> {
        /// Connection the request arrived on.
        pub connection_id: u64,
        /// Identifier of the request.
        pub request_id: u64,
        /// The request itself.
        pub request: R,
    }

    /// Asks the handler whether a transfer may continue.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct Throttle {
        /// Connection the transfer belongs to.
        pub connection_id: u64,
        /// Request the transfer belongs to.
        pub request_id: u64,
        /// The `len` value passed to `RequestTracker::transfer_progress`.
        pub size: u64,
    }

    /// Progress of an ongoing transfer.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct TransferProgress {
        /// The `end_offset` passed to `RequestTracker::transfer_progress`.
        pub end_offset: u64,
    }

    /// A transfer started.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct TransferStarted {
        /// The `index` passed to `RequestTracker::transfer_started`.
        pub index: u64,
        /// Hash of the blob being transferred.
        pub hash: Hash,
        /// The `size` passed to `RequestTracker::transfer_started`.
        pub size: u64,
    }

    /// A transfer completed.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct TransferCompleted {
        /// Statistics of the finished transfer.
        pub stats: Box<TransferStats>,
    }

    /// A transfer was aborted.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct TransferAborted {
        /// Statistics collected up to the abort.
        pub stats: Box<TransferStats>,
    }

    /// One item on a request's update stream.
    #[derive(Debug, Serialize, Deserialize, derive_more::From)]
    pub enum RequestUpdate {
        /// See [`TransferStarted`].
        Started(TransferStarted),
        /// See [`TransferProgress`].
        Progress(TransferProgress),
        /// See [`TransferCompleted`].
        Completed(TransferCompleted),
        /// See [`TransferAborted`].
        Aborted(TransferAborted),
    }
}
pub use proto::*;
657
/// Extension methods for `irpc::Client` used by the event sender.
mod irpc_ext {
    use std::future::Future;

    use irpc::{
        channel::{mpsc, none::NoSender},
        Channels, RpcMessage, Service, WithChannels,
    };

    /// Additional request patterns for irpc clients.
    pub trait IrpcClientExt<S: Service> {
        /// Sends `msg` as a notification that carries an update stream.
        ///
        /// No response channel is involved (`Tx = NoSender`). The returned
        /// sender pushes `Update` values to the handler. For an in-process
        /// service, `local_update_cap` is the capacity of the update channel.
        fn notify_streaming<Req, Update>(
            &self,
            msg: Req,
            local_update_cap: usize,
        ) -> impl Future<Output = irpc::Result<mpsc::Sender<Update>>>
        where
            S: From<Req>,
            S::Message: From<WithChannels<Req, S>>,
            Req: Channels<S, Tx = NoSender, Rx = mpsc::Receiver<Update>>,
            Update: RpcMessage;
    }

    impl<S: Service> IrpcClientExt<S> for irpc::Client<S> {
        fn notify_streaming<Req, Update>(
            &self,
            msg: Req,
            local_update_cap: usize,
        ) -> impl Future<Output = irpc::Result<mpsc::Sender<Update>>>
        where
            S: From<Req>,
            S::Message: From<WithChannels<Req, S>>,
            Req: Channels<S, Tx = NoSender, Rx = mpsc::Receiver<Update>>,
            Update: RpcMessage,
        {
            // Clone so the returned future owns its client instead of borrowing `self`.
            let client = self.clone();
            async move {
                let request = client.request().await?;
                match request {
                    irpc::Request::Local(local) => {
                        // In-process: create the update channel here and hand its
                        // receiving half to the service along with the message.
                        let (req_tx, req_rx) = mpsc::channel(local_update_cap);
                        local
                            .send((msg, NoSender, req_rx))
                            .await
                            .map_err(irpc::Error::from)?;
                        Ok(req_tx)
                    }
                    #[cfg(feature = "rpc")]
                    irpc::Request::Remote(remote) => {
                        // Remote: write the request and keep only the send side;
                        // the receive side is dropped because no response is expected.
                        let (s, _) = remote.write(msg).await?;
                        Ok(s.into())
                    }
                    #[cfg(not(feature = "rpc"))]
                    irpc::Request::Remote(_) => {
                        // NOTE(review): assumed impossible without the `rpc` feature — confirm against irpc.
                        unreachable!()
                    }
                }
            }
        }
    }
}