1use crate::{socket::linux::netlink::sealed::Sealed, *};
2use proc::beta;
3use std::{
4 convert::{TryFrom, TryInto},
5 mem,
6 mem::MaybeUninit,
7};
8
9const ALIGN: usize = 4 - 1;
10
11#[beta]
15pub fn nlmsg_align(size: usize) -> usize {
16 (size + ALIGN) & !ALIGN
17}
18
19#[beta]
26pub fn nlmsg_read<T: Pod>(buf: &mut &[u8]) -> Result<(usize, T)> {
27 let object_size = mem::size_of::<T>();
28 if buf.len() < object_size {
29 return einval();
30 }
31 let mut obj = MaybeUninit::<T>::uninit();
32 unsafe {
33 std::ptr::copy_nonoverlapping(
34 buf.as_ptr(),
35 obj.as_mut_ptr() as *mut u8,
36 object_size,
37 );
38 }
39 let space = nlmsg_align(object_size).min(buf.len());
40 *buf = &buf[space..];
41 let obj = unsafe { obj.assume_init() };
42 Ok((space, obj))
43}
44
45#[beta]
51#[allow(clippy::len_without_is_empty)]
/// A message header that stores the length of the message it introduces.
///
/// Within this file the length is treated as covering the header itself plus the payload:
/// the reader rejects lengths smaller than the aligned header size, and the writer passes
/// the number of bytes written so far, header included, to [`NlmsgHeader::set_len`].
pub trait NlmsgHeader: Sized {
    /// Returns the length recorded in this header, or an error if it cannot be
    /// represented as a `usize`.
    fn len(&self) -> Result<usize>;
    /// Records `len` in this header, or returns an error if the header's length field
    /// cannot hold the value.
    fn set_len(&mut self, len: usize) -> Result<()>;
}
64
// Holds the `Sealed` marker trait that `NlmsgHeaderExt` requires as a supertrait.
// The module itself is declared without `pub`.
mod sealed {
    // Marker trait with no methods; implemented below for every `NlmsgHeader`.
    pub trait Sealed {}
}
68
69#[beta]
73pub trait NlmsgHeaderExt: NlmsgHeader + Sealed {
74 fn read<'a>(buf: &mut &'a [u8]) -> Result<(usize, Self, &'a [u8])>
78 where
79 Self: Pod,
80 {
81 nlmsg_read_header(buf)
82 }
83}
84
85impl<T: NlmsgHeader> Sealed for T {
86}
87impl<T: NlmsgHeader> NlmsgHeaderExt for T {
88}
89
90impl NlmsgHeader for () {
91 fn len(&self) -> Result<usize> {
92 Ok(0)
93 }
94
95 fn set_len(&mut self, _len: usize) -> Result<()> {
96 Ok(())
97 }
98}
99
100macro_rules! nlh {
101 ($ty:ident, $field:ident) => {
102 impl NlmsgHeader for c::$ty {
103 fn len(&self) -> Result<usize> {
104 usize::try_from(self.$field).or_else(|_| einval())
105 }
106
107 fn set_len(&mut self, len: usize) -> Result<()> {
108 self.$field = match len.try_into() {
109 Ok(v) => v,
110 Err(_) => return einval(),
111 };
112 Ok(())
113 }
114 }
115 };
116}
117
118nlh!(nlmsghdr, nlmsg_len);
119nlh!(nlattr, nla_len);
120
121fn nlmsg_read_header<'a, H: Pod + NlmsgHeader>(
122 buf: &mut &'a [u8],
123) -> Result<(usize, H, &'a [u8])> {
124 let header_space = nlmsg_align(mem::size_of::<H>());
125 let hdr: H = {
126 let mut buf = *buf;
127 nlmsg_read(&mut buf)?.1
128 };
129 let len = hdr.len()?;
130 if len < header_space {
131 return einval();
132 }
133 if buf.len() < len {
134 return einval();
135 }
136 if usize::max_value() - len < ALIGN {
137 return einval();
138 }
139 let space = nlmsg_align(len).min(buf.len());
140 let data = &buf[header_space..len];
141 *buf = &buf[space..];
142 Ok((space, hdr, data))
143}
144
145#[beta]
/// Builds a message consisting of a header `H` followed by payload data inside a
/// caller-provided buffer.
///
/// The header is copied to the start of the buffer, with its length field filled in, when
/// the writer is finalized or dropped.
pub struct NlmsgWriter<'a, H: NlmsgHeader = ()> {
    /// Destination bytes; the header occupies the start of this slice.
    buf: &'a mut [MaybeUninit<u8>],
    /// Header value that is written to the start of `buf` at finalization.
    header: H,
    /// Number of bytes of `buf` used so far; starts at `size_of::<H>()`.
    len: usize,
    /// For writers created by `nest`: the parent's `len`, updated at finalization.
    parent_len: Option<&'a mut usize>,
}
155
156impl<'a, H: NlmsgHeader> NlmsgWriter<'a, H> {
157 pub fn new<T: Pod + ?Sized>(buf: &'a mut T, header: H) -> Result<Self> {
159 let buf = unsafe { as_maybe_uninit_bytes_mut2(buf) };
160 Self::new2(buf, None, header)
161 }
162
163 fn new2<'b, H2: NlmsgHeader>(
164 buf: &'b mut [MaybeUninit<u8>],
165 parent_len: Option<&'b mut usize>,
166 header: H2,
167 ) -> Result<NlmsgWriter<'b, H2>> {
168 let size = mem::size_of::<H2>();
169 if buf.len() < size {
170 return einval();
171 }
172 Ok(NlmsgWriter {
173 buf,
174 header,
175 len: size,
176 parent_len,
177 })
178 }
179
180 pub fn write<T: ?Sized>(&mut self, data: &T) -> Result<()> {
190 let aligned_len = nlmsg_align(self.len);
191 {
192 if aligned_len > self.buf.len() {
193 return einval();
194 }
195 let buf = &mut self.buf[aligned_len..];
196 let data_size = mem::size_of_val(data);
197 if buf.len() < data_size {
198 return einval();
199 }
200 unsafe {
201 let ptr = buf.as_mut_ptr();
202 ptr.copy_from_nonoverlapping(data as *const _ as *const _, data_size);
203 black_box(ptr);
204 }
205 }
206 self.len = aligned_len + mem::size_of_val(data);
207 Ok(())
208 }
209
210 pub fn nest<H2: NlmsgHeader>(&mut self, header: H2) -> Result<NlmsgWriter<H2>> {
219 let aligned_len = nlmsg_align(self.len);
220 if aligned_len >= self.buf.len() {
221 return einval();
222 }
223 Self::new2(&mut self.buf[aligned_len..], Some(&mut self.len), header)
224 }
225
226 fn finalize_mut(&mut self) -> Result<usize> {
227 self.header.set_len(self.len)?;
228 let ptr = self.buf.as_mut_ptr();
229 unsafe {
230 ptr.copy_from_nonoverlapping(
231 &self.header as *const _ as *const _,
232 mem::size_of::<H>(),
233 );
234 black_box(ptr);
235 }
236 if let Some(parent_len) = &mut self.parent_len {
237 **parent_len = nlmsg_align(**parent_len) + self.len;
238 }
239 Ok(self.len)
240 }
241
242 pub fn finalize(mut self) -> Result<&'a mut [u8]> {
246 let len = self.finalize_mut()?;
247 let buf = self.buf.as_mut_ptr();
248 mem::forget(self);
249 unsafe { Ok(std::slice::from_raw_parts_mut(buf, len).slice_assume_init_mut()) }
250 }
251}
252
253impl<'a, H: NlmsgHeader> Drop for NlmsgWriter<'a, H> {
254 fn drop(&mut self) {
255 self.finalize_mut().expect("could not finalize header");
256 }
257}
258
259#[cfg(test)]
260mod test {
261 use crate::*;
262 use std::mem::MaybeUninit;
263
    // Round-trips a message with nested attributes between two NETLINK_USERSOCK sockets
    // and checks that the reader recovers exactly what the writer produced.
    #[test]
    fn test_client_to_client() -> Result<()> {
        let s1 = socket(c::AF_NETLINK, c::SOCK_RAW, c::NETLINK_USERSOCK)?;
        let s2 = socket(c::AF_NETLINK, c::SOCK_RAW, c::NETLINK_USERSOCK)?;
        let mut addr = c::sockaddr_nl {
            nl_family: c::AF_NETLINK as _,
            nl_pad: 0,
            nl_pid: 0,
            nl_groups: 0,
        };
        bind(*s1, &addr)?;
        // Read the bound address back; it is used below as the destination for `s2`.
        // NOTE(review): presumably this picks up the port id assigned on bind — confirm.
        getsockname(*s1, &mut addr)?;
        let mut buf = [MaybeUninit::<u8>::uninit(); 128];
        let mut writer = NlmsgWriter::new(
            &mut buf[..],
            c::nlmsghdr {
                nlmsg_len: 0,
                nlmsg_type: 1,
                nlmsg_flags: 2,
                nlmsg_seq: 3,
                nlmsg_pid: 4,
            },
        )?;
        // Build: attribute 5 containing attribute 6 (one byte) and attribute 7 (a string).
        // Each nested writer is finalized when its scope ends.
        {
            let mut attr = writer.nest(c::nlattr {
                nla_len: 0,
                nla_type: 5,
            })?;
            {
                let mut attr = attr.nest(c::nlattr {
                    nla_len: 0,
                    nla_type: 6,
                })?;
                attr.write(&1u8)?;
            }
            {
                let mut attr = attr.nest(c::nlattr {
                    nla_len: 0,
                    nla_type: 7,
                })?;
                attr.write("hello world")?;
            }
        }
        let msg = writer.finalize()?;
        sendto(*s2, msg, 0, &addr)?;
        let mut reader = &*recv(*s1, &mut buf[..], 0)?;
        let (_, nlmsghdr, mut payload) = c::nlmsghdr::read(&mut reader)?;
        assert_eq!(nlmsghdr.nlmsg_type, 1);
        assert_eq!(nlmsghdr.nlmsg_flags, 2);
        assert_eq!(nlmsghdr.nlmsg_seq, 3);
        assert_eq!(nlmsghdr.nlmsg_pid, 4);
        {
            // This `payload` shadows the message payload and holds attribute 5's contents.
            let (_, outer_attr, mut payload) = c::nlattr::read(&mut payload)?;
            assert_eq!(outer_attr.nla_type, 5);
            {
                let (_, inner_attr, payload) = c::nlattr::read(&mut payload)?;
                assert_eq!(inner_attr.nla_type, 6);
                assert_eq!(pod_read::<u8, _>(payload)?, 1);
            }
            {
                let (_, inner_attr, payload) = c::nlattr::read(&mut payload)?;
                assert_eq!(inner_attr.nla_type, 7);
                assert_eq!(payload, b"hello world");
            }
            // Both inner attributes were consumed from attribute 5's payload.
            assert!(payload.is_empty());
        }
        // Nothing follows attribute 5 in the message, and nothing follows the message.
        assert!(payload.is_empty());
        assert!(reader.is_empty());
        Ok(())
    }
334
    // Sends an RTM_GETLINK dump request over a NETLINK_ROUTE socket and walks the replies,
    // checking that every link carries a name and that a loopback link is reported.
    #[test]
    fn test_rt_netlink() -> Result<()> {
        let socket = socket(c::AF_NETLINK, c::SOCK_RAW, c::NETLINK_ROUTE)?;
        let addr = c::sockaddr_nl {
            nl_family: c::AF_NETLINK as _,
            nl_pad: 0,
            nl_pid: 0,
            nl_groups: 0,
        };
        bind(*socket, &addr)?;
        let mut buf = [MaybeUninit::<u8>::uninit(); 32 * 1024];
        let mut writer = NlmsgWriter::new(
            &mut buf[..],
            c::nlmsghdr {
                nlmsg_len: 0,
                nlmsg_type: c::RTM_GETLINK,
                nlmsg_flags: (c::NLM_F_REQUEST | c::NLM_F_DUMP) as _,
                nlmsg_seq: 0,
                nlmsg_pid: 0,
            },
        )?;
        // Request body: an `ifinfomsg` followed by a single IFLA_EXT_MASK attribute.
        writer.write(&c::ifinfomsg {
            ifi_family: c::AF_PACKET as _,
            ifi_type: 0,
            ifi_index: 0,
            ifi_flags: 0,
            ifi_change: 0,
        })?;
        {
            let mut attr = writer.nest(c::nlattr {
                nla_len: 0,
                nla_type: c::IFLA_EXT_MASK,
            })?;
            attr.write(&1u32)?;
        }
        let msg = writer.finalize()?;
        send(*socket, msg, 0)?;
        let mut found_loopback = false;
        // Keep receiving until NLMSG_DONE or a message without NLM_F_MULTI shows up.
        'outer: loop {
            let mut reader = &*recv(*socket, &mut buf[..], c::MSG_TRUNC)?;
            // One receive may hold several messages back to back.
            while reader.len() > 0 {
                let (_, header, mut payload) = c::nlmsghdr::read(&mut reader)?;
                if header.nlmsg_type == c::NLMSG_DONE as _ {
                    break 'outer;
                }
                assert_eq!(header.nlmsg_type, c::RTM_NEWLINK);
                let (_, ifi) = nlmsg_read::<c::ifinfomsg>(&mut payload)?;
                let is_loopback = ifi.ifi_type == c::ARPHRD_LOOPBACK;
                if is_loopback {
                    found_loopback = true;
                    assert_eq!(ifi.ifi_family, c::AF_UNSPEC as c::c_uchar);
                    assert_ne!(ifi.ifi_flags & c::IFF_UP as c::c_uint, 0);
                    assert_ne!(ifi.ifi_flags & c::IFF_LOOPBACK as c::c_uint, 0);
                }
                let mut found_name = false;
                // The rest of the payload is a sequence of attributes.
                while payload.len() > 0 {
                    // These bindings shadow the message header and payload inside the loop body.
                    let (_, header, payload) = c::nlattr::read(&mut payload)?;
                    if header.nla_type == c::IFLA_IFNAME {
                        found_name = true;
                        if is_loopback {
                            assert_eq!(payload, b"lo\0");
                        }
                    }
                }
                assert!(found_name);
                if header.nlmsg_flags & c::NLM_F_MULTI as u16 == 0 {
                    break 'outer;
                }
            }
        }
        assert!(found_loopback);
        Ok(())
    }
408}