Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit 51cb087

Browse files
committed
add socket_protector and minor refactor
1 parent f70a90f commit 51cb087

File tree

6 files changed

+181
-120
lines changed

6 files changed

+181
-120
lines changed

Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "rstun"
3-
version = "0.8.2"
3+
version = "0.8.3"
44
edition = "2024"
55
license = "MIT"
66

src/client.rs

Lines changed: 91 additions & 106 deletions
Original file line number | Diff line number | Diff line change
@@ -5,8 +5,8 @@ use crate::{
55
pem_util, socket_addr_with_unspecified_ip_port,
66
tcp::{AsyncStream, StreamReceiver, StreamRequest, tcp_tunnel::TcpTunnel},
77
tunnel_event_bus::{
8-
TunnelDescriptor, TunnelEvent, TunnelEventBus, TunnelEventType, TunnelId, TunnelState,
9-
TunnelTraffic,
8+
TunnelDescriptor, TunnelEvent, TunnelEventBus, TunnelEventType, TunnelId, TunnelStat,
9+
TunnelState,
1010
},
1111
tunnel_message::TunnelMessage,
1212
udp::{UdpReceiver, UdpSender, udp_server::UdpServer, udp_tunnel::UdpTunnel},
@@ -70,8 +70,7 @@ struct State {
7070
endpoint: Option<Endpoint>,
7171
connections: HashMap<usize, ConnectionEntry>,
7272
client_state: ClientState,
73-
tunnel_traffic: HashMap<TunnelId, TunnelTraffic>,
74-
tunnel_registry: HashMap<TunnelId, TunnelDescriptor>,
73+
tunnel_stat: TunnelStat,
7574
event_bus: TunnelEventBus,
7675
}
7776

@@ -84,8 +83,7 @@ impl State {
8483
endpoint: None,
8584
connections: HashMap::new(),
8685
client_state: ClientState::Idle,
87-
tunnel_traffic: HashMap::new(),
88-
tunnel_registry: HashMap::new(),
86+
tunnel_stat: TunnelStat::default(),
8987
event_bus: TunnelEventBus::new(),
9088
}
9189
}
@@ -94,7 +92,6 @@ impl State {
9492
#[derive(Clone)]
9593
struct ConnectionEntry {
9694
conn: Connection,
97-
tunnel_id: TunnelId,
9895
}
9996

10097
struct LoginConfig {
@@ -246,7 +243,7 @@ impl Client {
246243
async fn migrate_endpoint(endpoint: &Endpoint) -> Result<()> {
247244
let current_addr = endpoint.local_addr()?;
248245
let new_addr = socket_addr_with_unspecified_ip_port(current_addr.is_ipv6());
249-
let socket = std::net::UdpSocket::bind(new_addr)?;
246+
let socket = Self::bind_client_udp_socket(new_addr)?;
250247
debug!(
251248
"endpoint migration, from_addr:{current_addr}, to_addr:{}",
252249
socket.local_addr()?
@@ -255,6 +252,12 @@ impl Client {
255252
Ok(())
256253
}
257254

255+
fn bind_client_udp_socket(bind_addr: SocketAddr) -> Result<std::net::UdpSocket> {
256+
let socket = std::net::UdpSocket::bind(bind_addr)?;
257+
crate::protect_udp_socket(&socket)?;
258+
Ok(socket)
259+
}
260+
258261
pub async fn start_tcp_server(&self, addr: SocketAddr) -> Result<TcpServer> {
259262
let bind_tcp_server = || async { TcpServer::bind_and_start(addr).await };
260263
let tcp_server = bind_tcp_server
@@ -376,7 +379,6 @@ impl Client {
376379
};
377380
let tunnel_descriptor =
378381
self.build_tunnel_descriptor(tunnel_id.clone(), &tunnel, &connect_input);
379-
self.register_tunnel_descriptor(tunnel_descriptor.clone());
380382

381383
let mut pending_network_request = None;
382384
let mut pending_channel_request = None;
@@ -419,13 +421,8 @@ impl Client {
419421
let ConnectInput::Network = connect_input else {
420422
unreachable!("Network-based tunnel requires ConnectInput::Network");
421423
};
422-
inner_state!(self, connections).insert(
423-
conn.stable_id(),
424-
ConnectionEntry {
425-
conn: conn.clone(),
426-
tunnel_id: tunnel_id.clone(),
427-
},
428-
);
424+
inner_state!(self, connections)
425+
.insert(conn.stable_id(), ConnectionEntry { conn: conn.clone() });
429426

430427
self.handle_network_based_tunnel(
431428
&tunnel_descriptor,
@@ -438,13 +435,8 @@ impl Client {
438435
inner_state!(self, connections).remove(&conn.stable_id());
439436
}
440437
Tunnel::ChannelBased(upstream_type) => {
441-
inner_state!(self, connections).insert(
442-
conn.stable_id(),
443-
ConnectionEntry {
444-
conn: conn.clone(),
445-
tunnel_id: tunnel_id.clone(),
446-
},
447-
);
438+
inner_state!(self, connections)
439+
.insert(conn.stable_id(), ConnectionEntry { conn: conn.clone() });
448440
self.handle_channel_based_tunnel(
449441
&tunnel_descriptor,
450442
&conn,
@@ -490,8 +482,8 @@ impl Client {
490482
let endpoint = if let Some(endpoint) = endpoint {
491483
endpoint
492484
} else {
493-
let mut endpoint = quinn::Endpoint::client(login_cfg.local_addr)?;
494-
endpoint.set_default_client_config(login_cfg.quinn_client_cfg);
485+
let endpoint =
486+
Self::create_client_endpoint(login_cfg.local_addr, login_cfg.quinn_client_cfg)?;
495487
inner_state!(self, endpoint) = Some(endpoint.clone());
496488
endpoint
497489
};
@@ -506,6 +498,21 @@ impl Client {
506498
.await
507499
}
508500

501+
fn create_client_endpoint(
502+
local_addr: SocketAddr,
503+
quinn_client_cfg: quinn::ClientConfig,
504+
) -> Result<Endpoint> {
505+
let socket = Self::bind_client_udp_socket(local_addr)?;
506+
let mut endpoint = quinn::Endpoint::new(
507+
quinn::EndpointConfig::default(),
508+
None,
509+
socket,
510+
quinn::default_runtime().ok_or_else(|| anyhow::anyhow!("no async runtime found"))?,
511+
)?;
512+
endpoint.set_default_client_config(quinn_client_cfg);
513+
Ok(endpoint)
514+
}
515+
509516
async fn handle_network_based_tunnel(
510517
&mut self,
511518
tunnel: &TunnelDescriptor,
@@ -557,7 +564,7 @@ impl Client {
557564
}
558565
}
559566

560-
self.add_connection_stats_to_total(tunnel.id.clone(), conn);
567+
self.add_connection_stats_to_total(conn);
561568
}
562569

563570
async fn handle_channel_based_tunnel<S: AsyncStream>(
@@ -599,21 +606,21 @@ impl Client {
599606
_ => unreachable!("Channel-based tunnel requires ConnectInput::Channel"),
600607
}
601608

602-
self.add_connection_stats_to_total(tunnel.id.clone(), conn);
609+
self.add_connection_stats_to_total(conn);
603610
}
604611

605-
fn add_connection_stats_to_total(&self, tunnel_id: TunnelId, conn: &Connection) {
612+
fn add_connection_stats_to_total(&self, conn: &Connection) {
606613
let stats = conn.stats();
607614
let mut state = lock_state(&self.inner_state);
608-
let data = state.tunnel_traffic.entry(tunnel_id).or_default();
609-
data.rx_bytes += stats.udp_rx.bytes;
610-
data.tx_bytes += stats.udp_tx.bytes;
611-
data.rx_dgrams += stats.udp_rx.datagrams;
612-
data.tx_dgrams += stats.udp_tx.datagrams;
613-
data.sent_packets += stats.path.sent_packets;
614-
data.lost_packets += stats.path.lost_packets;
615-
data.lost_bytes += stats.path.lost_bytes;
616-
data.congestion_events += stats.path.congestion_events;
615+
let stat = &mut state.tunnel_stat;
616+
stat.rx_bytes += stats.udp_rx.bytes;
617+
stat.tx_bytes += stats.udp_tx.bytes;
618+
stat.rx_dgrams += stats.udp_rx.datagrams;
619+
stat.tx_dgrams += stats.udp_tx.datagrams;
620+
stat.sent_packets += stats.path.sent_packets;
621+
stat.lost_packets += stats.path.lost_packets;
622+
stat.lost_bytes += stats.path.lost_bytes;
623+
stat.congestion_events += stats.path.congestion_events;
617624
}
618625

619626
async fn prepare_login_config(&self) -> Result<LoginConfig> {
@@ -939,80 +946,65 @@ impl Client {
939946
loop {
940947
interval.tick().await;
941948

942-
let (tunnel_traffic, tunnel_registry, client_state, event_bus) = {
949+
let (stat, client_state, event_bus) = {
943950
let locked_state = lock_state(&state);
944-
let mut tunnel_traffic: HashMap<TunnelId, TunnelTraffic> = HashMap::new();
951+
let mut stat = TunnelStat::default();
945952

946953
for entry in locked_state.connections.values() {
947954
let stats = entry.conn.stats();
948-
let data = tunnel_traffic.entry(entry.tunnel_id.clone()).or_default();
949-
data.rx_bytes += stats.udp_rx.bytes;
950-
data.tx_bytes += stats.udp_tx.bytes;
951-
data.rx_dgrams += stats.udp_rx.datagrams;
952-
data.tx_dgrams += stats.udp_tx.datagrams;
953-
data.sent_packets += stats.path.sent_packets;
954-
data.lost_packets += stats.path.lost_packets;
955-
data.lost_bytes += stats.path.lost_bytes;
956-
data.congestion_events += stats.path.congestion_events;
957-
data.active_conns += 1;
958-
data.rtt_ms += stats.path.rtt.as_millis() as u64;
959-
data.cwnd_bytes += stats.path.cwnd;
960-
data.current_mtu = data.current_mtu.max(stats.path.current_mtu);
955+
stat.rx_bytes += stats.udp_rx.bytes;
956+
stat.tx_bytes += stats.udp_tx.bytes;
957+
stat.rx_dgrams += stats.udp_rx.datagrams;
958+
stat.tx_dgrams += stats.udp_tx.datagrams;
959+
stat.sent_packets += stats.path.sent_packets;
960+
stat.lost_packets += stats.path.lost_packets;
961+
stat.lost_bytes += stats.path.lost_bytes;
962+
stat.congestion_events += stats.path.congestion_events;
963+
stat.active_conns += 1;
964+
stat.rtt_ms = stat.rtt_ms.max(stats.path.rtt.as_millis() as u64);
965+
stat.cwnd_bytes = stat.cwnd_bytes.max(stats.path.cwnd);
966+
stat.current_mtu = stat.current_mtu.max(stats.path.current_mtu);
961967
}
962968

963-
for (tunnel_id, total) in &locked_state.tunnel_traffic {
964-
let data = tunnel_traffic.entry(tunnel_id.clone()).or_default();
965-
data.rx_bytes += total.rx_bytes;
966-
data.tx_bytes += total.tx_bytes;
967-
data.rx_dgrams += total.rx_dgrams;
968-
data.tx_dgrams += total.tx_dgrams;
969-
data.sent_packets += total.sent_packets;
970-
data.lost_packets += total.lost_packets;
971-
data.lost_bytes += total.lost_bytes;
972-
data.congestion_events += total.congestion_events;
973-
}
969+
stat.rx_bytes += locked_state.tunnel_stat.rx_bytes;
970+
stat.tx_bytes += locked_state.tunnel_stat.tx_bytes;
971+
stat.rx_dgrams += locked_state.tunnel_stat.rx_dgrams;
972+
stat.tx_dgrams += locked_state.tunnel_stat.tx_dgrams;
973+
stat.sent_packets += locked_state.tunnel_stat.sent_packets;
974+
stat.lost_packets += locked_state.tunnel_stat.lost_packets;
975+
stat.lost_bytes += locked_state.tunnel_stat.lost_bytes;
976+
stat.congestion_events += locked_state.tunnel_stat.congestion_events;
974977

975978
(
976-
tunnel_traffic,
977-
locked_state.tunnel_registry.clone(),
979+
stat,
978980
locked_state.client_state.clone(),
979981
locked_state.event_bus.clone(),
980982
)
981983
};
982984

983985
let timestamp = chrono::Local::now().format(TIME_FORMAT).to_string();
984-
for (tunnel_id, mut data) in tunnel_traffic {
985-
if let Some(descriptor) = tunnel_registry.get(&tunnel_id) {
986-
if data.active_conns > 0 {
987-
data.rtt_ms /= data.active_conns as u64;
988-
data.cwnd_bytes /= data.active_conns as u64;
989-
}
990-
991-
if log_enabled!(Level::Info) {
992-
info!(
993-
"[{descriptor}] traffic rx_bytes:{}, tx_bytes:{}, rx_dgrams:{}, tx_dgrams:{}, sent_packets:{}, lost_packets:{}, lost_bytes:{}, congestion_events:{}, active_conns:{}, rtt_ms:{}, cwnd_bytes:{}, current_mtu:{}",
994-
data.rx_bytes,
995-
data.tx_bytes,
996-
data.rx_dgrams,
997-
data.tx_dgrams,
998-
data.sent_packets,
999-
data.lost_packets,
1000-
data.lost_bytes,
1001-
data.congestion_events,
1002-
data.active_conns,
1003-
data.rtt_ms,
1004-
data.cwnd_bytes,
1005-
data.current_mtu
1006-
);
1007-
}
1008-
if event_bus.has_listeners() {
1009-
event_bus.post(TunnelEvent::new(
1010-
timestamp.clone(),
1011-
descriptor.clone(),
1012-
TunnelEventType::Traffic(data),
1013-
));
1014-
}
1015-
}
986+
if log_enabled!(Level::Info) {
987+
info!(
988+
"traffic rx_bytes:{}, tx_bytes:{}, rx_dgrams:{}, tx_dgrams:{}, sent_packets:{}, lost_packets:{}, lost_bytes:{}, congestion_events:{}, active_conns:{}, rtt_ms:{}, cwnd_bytes:{}, current_mtu:{}",
989+
stat.rx_bytes,
990+
stat.tx_bytes,
991+
stat.rx_dgrams,
992+
stat.tx_dgrams,
993+
stat.sent_packets,
994+
stat.lost_packets,
995+
stat.lost_bytes,
996+
stat.congestion_events,
997+
stat.active_conns,
998+
stat.rtt_ms,
999+
stat.cwnd_bytes,
1000+
stat.current_mtu
1001+
);
1002+
}
1003+
if event_bus.has_listeners() {
1004+
event_bus.post(TunnelEvent::new_without_tunnel(
1005+
TunnelEventType::Stat(stat),
1006+
timestamp.clone(),
1007+
));
10161008
}
10171009

10181010
if client_state == ClientState::Stopping || client_state == ClientState::Terminated
@@ -1187,17 +1179,17 @@ impl Client {
11871179
fn post_tunnel_log(&self, tunnel: &TunnelDescriptor, msg: &str) {
11881180
info!("[{tunnel}] {msg}");
11891181
self.post_event(TunnelEvent::new(
1182+
TunnelEventType::Log(msg.to_string()),
11901183
chrono::Local::now().format(TIME_FORMAT).to_string(),
11911184
tunnel.clone(),
1192-
TunnelEventType::Log(msg.to_string()),
11931185
));
11941186
}
11951187

11961188
fn post_tunnel_state(&self, tunnel: &TunnelDescriptor, tunnel_state: TunnelState) {
11971189
self.post_event(TunnelEvent::new(
1190+
TunnelEventType::State(tunnel_state),
11981191
chrono::Local::now().format(TIME_FORMAT).to_string(),
11991192
tunnel.clone(),
1200-
TunnelEventType::State(tunnel_state),
12011193
));
12021194
}
12031195

@@ -1236,13 +1228,6 @@ impl Client {
12361228
}
12371229
}
12381230

1239-
fn register_tunnel_descriptor(&self, descriptor: TunnelDescriptor) {
1240-
let mut state = lock_state(&self.inner_state);
1241-
state
1242-
.tunnel_registry
1243-
.insert(descriptor.id.clone(), descriptor);
1244-
}
1245-
12461231
pub fn register_for_events(&self) -> std::sync::mpsc::Receiver<TunnelEvent> {
12471232
let event_bus = lock_state(&self.inner_state).event_bus.clone();
12481233
event_bus.register()

0 commit comments

Comments
 (0)