1use alloc::vec::Vec;
8use core::{
9 fmt::{self, Debug},
10 hash::{BuildHasher, Hash},
11};
12
13use bevy_platform::{collections::HashSet, hash::FixedHasher};
14use indexmap::IndexMap;
15use smallvec::SmallVec;
16
17use Direction::{Incoming, Outgoing};
18
19pub trait GraphNodeId: Copy + Eq + Hash + Ord + Debug {
24 type Adjacent: Copy + Debug + From<(Self, Direction)> + Into<(Self, Direction)>;
27 type Edge: Copy + Eq + Hash + Debug + From<(Self, Self)> + Into<(Self, Self)>;
30
31 fn kind(&self) -> &'static str;
36}
37
/// A [`Graph`] with `DIRECTED = false`, i.e. an undirected graph.
///
/// For this alias `Graph::edge_key` orders the two endpoints before building
/// the key, so `(a, b)` and `(b, a)` name the same edge.
pub type UnGraph<N, S = FixedHasher> = Graph<false, N, S>;
43
/// A [`Graph`] with `DIRECTED = true`, i.e. a directed graph.
///
/// For this alias edge keys keep the endpoint order they were given, so
/// `(a, b)` and `(b, a)` are distinct edges.
pub type DiGraph<N, S = FixedHasher> = Graph<true, N, S>;
49
50#[derive(Clone)]
68pub struct Graph<const DIRECTED: bool, N: GraphNodeId, S = FixedHasher>
69where
70 S: BuildHasher,
71{
72 nodes: IndexMap<N, Vec<N::Adjacent>, S>,
73 edges: HashSet<N::Edge, S>,
74}
75
76impl<const DIRECTED: bool, N: GraphNodeId, S: BuildHasher> Debug for Graph<DIRECTED, N, S> {
77 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
78 self.nodes.fmt(f)
79 }
80}
81
82impl<const DIRECTED: bool, N: GraphNodeId, S: BuildHasher> Graph<DIRECTED, N, S> {
83 pub fn with_capacity(nodes: usize, edges: usize) -> Self
85 where
86 S: Default,
87 {
88 Self {
89 nodes: IndexMap::with_capacity_and_hasher(nodes, S::default()),
90 edges: HashSet::with_capacity_and_hasher(edges, S::default()),
91 }
92 }
93
94 #[inline]
96 fn edge_key(a: N, b: N) -> N::Edge {
97 let (a, b) = if DIRECTED || a <= b { (a, b) } else { (b, a) };
98
99 N::Edge::from((a, b))
100 }
101
102 pub fn node_count(&self) -> usize {
104 self.nodes.len()
105 }
106
107 pub fn edge_count(&self) -> usize {
109 self.edges.len()
110 }
111
112 pub fn add_node(&mut self, n: N) {
114 self.nodes.entry(n).or_default();
115 }
116
117 pub fn remove_node(&mut self, n: N) {
121 let Some(links) = self.nodes.swap_remove(&n) else {
122 return;
123 };
124
125 let links = links.into_iter().map(N::Adjacent::into);
126
127 for (succ, dir) in links {
128 let edge = if dir == Outgoing {
129 Self::edge_key(n, succ)
130 } else {
131 Self::edge_key(succ, n)
132 };
133 self.remove_single_edge(succ, n, dir.opposite());
135 self.edges.remove(&edge);
137 }
138 }
139
140 pub fn contains_node(&self, n: N) -> bool {
142 self.nodes.contains_key(&n)
143 }
144
145 pub fn add_edge(&mut self, a: N, b: N) {
150 if self.edges.insert(Self::edge_key(a, b)) {
151 self.nodes
153 .entry(a)
154 .or_insert_with(|| Vec::with_capacity(1))
155 .push(N::Adjacent::from((b, Outgoing)));
156 if a != b {
157 self.nodes
159 .entry(b)
160 .or_insert_with(|| Vec::with_capacity(1))
161 .push(N::Adjacent::from((a, Incoming)));
162 }
163 }
164 }
165
166 fn remove_single_edge(&mut self, a: N, b: N, dir: Direction) -> bool {
170 let Some(sus) = self.nodes.get_mut(&a) else {
171 return false;
172 };
173
174 let Some(index) = sus
175 .iter()
176 .copied()
177 .map(N::Adjacent::into)
178 .position(|elt| (DIRECTED && elt == (b, dir)) || (!DIRECTED && elt.0 == b))
179 else {
180 return false;
181 };
182
183 sus.swap_remove(index);
184 true
185 }
186
187 pub fn remove_edge(&mut self, a: N, b: N) -> bool {
191 let exist1 = self.remove_single_edge(a, b, Outgoing);
192 let exist2 = if a != b {
193 self.remove_single_edge(b, a, Incoming)
194 } else {
195 exist1
196 };
197 let weight = self.edges.remove(&Self::edge_key(a, b));
198 debug_assert!(exist1 == exist2 && exist1 == weight);
199 weight
200 }
201
202 pub fn contains_edge(&self, a: N, b: N) -> bool {
204 self.edges.contains(&Self::edge_key(a, b))
205 }
206
207 pub fn nodes(&self) -> impl DoubleEndedIterator<Item = N> + ExactSizeIterator<Item = N> + '_ {
209 self.nodes.keys().copied()
210 }
211
212 pub fn neighbors(&self, a: N) -> impl DoubleEndedIterator<Item = N> + '_ {
214 let iter = match self.nodes.get(&a) {
215 Some(neigh) => neigh.iter(),
216 None => [].iter(),
217 };
218
219 iter.copied()
220 .map(N::Adjacent::into)
221 .filter_map(|(n, dir)| (!DIRECTED || dir == Outgoing).then_some(n))
222 }
223
224 pub fn neighbors_directed(
228 &self,
229 a: N,
230 dir: Direction,
231 ) -> impl DoubleEndedIterator<Item = N> + '_ {
232 let iter = match self.nodes.get(&a) {
233 Some(neigh) => neigh.iter(),
234 None => [].iter(),
235 };
236
237 iter.copied()
238 .map(N::Adjacent::into)
239 .filter_map(move |(n, d)| (!DIRECTED || d == dir || n == a).then_some(n))
240 }
241
242 pub fn edges(&self, a: N) -> impl DoubleEndedIterator<Item = (N, N)> + '_ {
245 self.neighbors(a)
246 .map(move |b| match self.edges.get(&Self::edge_key(a, b)) {
247 None => unreachable!(),
248 Some(_) => (a, b),
249 })
250 }
251
252 pub fn edges_directed(
255 &self,
256 a: N,
257 dir: Direction,
258 ) -> impl DoubleEndedIterator<Item = (N, N)> + '_ {
259 self.neighbors_directed(a, dir).map(move |b| {
260 let (a, b) = if dir == Incoming { (b, a) } else { (a, b) };
261
262 match self.edges.get(&Self::edge_key(a, b)) {
263 None => unreachable!(),
264 Some(_) => (a, b),
265 }
266 })
267 }
268
269 pub fn all_edges(&self) -> impl ExactSizeIterator<Item = (N, N)> + '_ {
271 self.edges.iter().copied().map(N::Edge::into)
272 }
273
274 pub(crate) fn to_index(&self, ix: N) -> usize {
275 self.nodes.get_index_of(&ix).unwrap()
276 }
277
278 pub fn try_into<T: GraphNodeId + TryFrom<N>>(self) -> Result<Graph<DIRECTED, T, S>, T::Error>
288 where
289 S: Default,
290 {
291 fn try_convert_node<N: GraphNodeId, T: GraphNodeId + TryFrom<N>>(
293 (key, adj): (N, Vec<N::Adjacent>),
294 ) -> Result<(T, Vec<T::Adjacent>), T::Error> {
295 let key = key.try_into()?;
296 let adj = adj
297 .into_iter()
298 .map(|node| {
299 let (id, dir) = node.into();
300 Ok(T::Adjacent::from((id.try_into()?, dir)))
301 })
302 .collect::<Result<_, T::Error>>()?;
303 Ok((key, adj))
304 }
305 fn try_convert_edge<N: GraphNodeId, T: GraphNodeId + TryFrom<N>>(
307 edge: N::Edge,
308 ) -> Result<T::Edge, T::Error> {
309 let (a, b) = edge.into();
310 Ok(T::Edge::from((a.try_into()?, b.try_into()?)))
311 }
312
313 let nodes = self
314 .nodes
315 .into_iter()
316 .map(try_convert_node::<N, T>)
317 .collect::<Result<_, T::Error>>()?;
318 let edges = self
319 .edges
320 .into_iter()
321 .map(try_convert_edge::<N, T>)
322 .collect::<Result<_, T::Error>>()?;
323 Ok(Graph { nodes, edges })
324 }
325}
326
327impl<const DIRECTED: bool, N, S> Default for Graph<DIRECTED, N, S>
329where
330 N: GraphNodeId,
331 S: BuildHasher + Default,
332{
333 fn default() -> Self {
334 Self::with_capacity(0, 0)
335 }
336}
337
338impl<N: GraphNodeId, S: BuildHasher> DiGraph<N, S> {
339 pub(crate) fn iter_sccs(&self) -> impl Iterator<Item = SmallVec<[N; 4]>> + '_ {
341 super::tarjan_scc::new_tarjan_scc(self)
342 }
343}
344
/// Direction of an edge relative to the node whose adjacency list stores it.
///
/// The explicit discriminants together with `#[repr(u8)]` fix the numeric
/// value of each variant.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(u8)]
pub enum Direction {
    /// The edge leaves the node.
    Outgoing = 0,
    /// The edge enters the node.
    Incoming = 1,
}

impl Direction {
    /// Returns the reverse direction: `Outgoing` becomes `Incoming` and vice
    /// versa.
    #[inline]
    pub fn opposite(self) -> Self {
        if matches!(self, Self::Outgoing) {
            Self::Incoming
        } else {
            Self::Outgoing
        }
    }
}
365
#[cfg(test)]
mod tests {
    use crate::schedule::{NodeId, SystemKey};

    use super::*;
    use alloc::vec;
    use slotmap::SlotMap;

    // Removing a node uses `swap_remove`: the last node moves into the freed
    // position and the relative order of the remaining nodes is otherwise kept.
    #[test]
    fn node_order_preservation() {
        use NodeId::System;

        let mut slotmap = SlotMap::<SystemKey, ()>::with_key();
        let mut graph = DiGraph::<NodeId>::default();

        let sys1 = System(slotmap.insert(()));
        let sys2 = System(slotmap.insert(()));
        let sys3 = System(slotmap.insert(()));
        let sys4 = System(slotmap.insert(()));

        for node in [sys1, sys2, sys3, sys4] {
            graph.add_node(node);
        }

        // Snapshot of the current node iteration order.
        let order = |graph: &DiGraph<NodeId>| graph.nodes().collect::<Vec<_>>();

        assert_eq!(order(&graph), vec![sys1, sys2, sys3, sys4]);

        // `sys4` (last) is swapped into the slot `sys1` occupied.
        graph.remove_node(sys1);
        assert_eq!(order(&graph), vec![sys4, sys2, sys3]);

        // `sys3` (now last) is swapped into the slot `sys4` occupied.
        graph.remove_node(sys4);
        assert_eq!(order(&graph), vec![sys3, sys2]);

        // Removing the last node leaves the rest untouched.
        graph.remove_node(sys2);
        assert_eq!(order(&graph), vec![sys3]);

        graph.remove_node(sys3);
        assert_eq!(order(&graph), vec![]);
    }

    // Edges: 1 <-> 2 <-> 3, 4 <-> 5, and 6 -> 2. The expected grouping and
    // ordering below is what `iter_sccs` reports for that graph.
    #[test]
    fn strongly_connected_components() {
        use NodeId::System;

        let mut slotmap = SlotMap::<SystemKey, ()>::with_key();
        let mut graph = DiGraph::<NodeId>::default();

        let sys1 = System(slotmap.insert(()));
        let sys2 = System(slotmap.insert(()));
        let sys3 = System(slotmap.insert(()));
        let sys4 = System(slotmap.insert(()));
        let sys5 = System(slotmap.insert(()));
        let sys6 = System(slotmap.insert(()));

        let edge_list = [
            (sys1, sys2),
            (sys2, sys1),
            (sys2, sys3),
            (sys3, sys2),
            (sys4, sys5),
            (sys5, sys4),
            (sys6, sys2),
        ];
        for (from, to) in edge_list {
            graph.add_edge(from, to);
        }

        let sccs: Vec<Vec<NodeId>> = graph.iter_sccs().map(|scc| scc.to_vec()).collect();

        let expected = vec![vec![sys3, sys2, sys1], vec![sys5, sys4], vec![sys6]];
        assert_eq!(sccs, expected);
    }
}