// Mirror of https://github.com/erg-lang/erg.git
// Synced 2025-07-07 21:25:31 +00:00
// 191 lines, 5.1 KiB, Rust
//! Topological sort
|
|
use crate::dict::Dict;
|
|
use crate::set::Set;
|
|
use crate::traits::Immutable;
|
|
|
|
use std::fmt::Debug;
|
|
use std::hash::Hash;
|
|
|
|
/// Category of failure reported by the topological sort.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TopoSortErrorKind {
    /// The dependency graph contains a cycle.
    CyclicReference,
    /// An id was looked up that no node in the graph carries.
    KeyNotFound,
}

/// The user-facing text of a kind is its `Debug` rendering (the variant name).
impl std::fmt::Display for TopoSortErrorKind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!("{self:?}"))
    }
}

impl std::error::Error for TopoSortErrorKind {}
|
|
|
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
|
pub struct TopoSortError {
|
|
pub kind: TopoSortErrorKind,
|
|
pub msg: String,
|
|
}
|
|
|
|
impl std::fmt::Display for TopoSortError {
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
write!(f, "{}: {}", self.kind, self.msg)
|
|
}
|
|
}
|
|
|
|
impl std::error::Error for TopoSortError {}
|
|
|
|
impl TopoSortError {
|
|
pub fn new(kind: TopoSortErrorKind, msg: String) -> Self {
|
|
Self { kind, msg }
|
|
}
|
|
|
|
pub fn key_not_found(msg: String) -> Self {
|
|
Self::new(TopoSortErrorKind::KeyNotFound, msg)
|
|
}
|
|
|
|
pub fn cycle_detected(msg: String) -> Self {
|
|
Self::new(TopoSortErrorKind::CyclicReference, msg)
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
|
pub struct Node<T: Eq + Hash + Immutable, U> {
|
|
pub id: T,
|
|
pub data: U,
|
|
pub depends_on: Set<T>,
|
|
}
|
|
|
|
impl<T: Eq + Hash + Immutable, U> Node<T, U> {
|
|
pub const fn new(id: T, data: U, depends_on: Set<T>) -> Self {
|
|
Node {
|
|
id,
|
|
data,
|
|
depends_on,
|
|
}
|
|
}
|
|
|
|
pub fn push_dep(&mut self, dep: T) {
|
|
self.depends_on.insert(dep);
|
|
}
|
|
|
|
pub fn pop_dep(&mut self, dep: &T) -> bool {
|
|
self.depends_on.remove(dep)
|
|
}
|
|
|
|
pub fn depends_on(&self, dep: &T) -> bool {
|
|
self.depends_on.contains(dep)
|
|
}
|
|
}
|
|
|
|
/// A graph stored as a flat list of [`Node`]s; edges are given by each node's
/// `depends_on` set.
pub type Graph<T, U> = Vec<Node<T, U>>;
|
|
|
|
/// Rearranges `v` with in-place swaps so that slot `i` ends up holding the
/// element that started at position `idx[i]`.
///
/// NOTE(review): this appears to assume `idx` is a permutation of
/// `0..v.len()`; other inputs are not guarded against — confirm with callers.
fn _reorder_by_idx<T>(mut v: Vec<T>, idx: Vec<usize>) -> Vec<T> {
    // For every slot already processed, remembers the slot its previous
    // occupant was swapped into.
    let mut relocated = Dict::new();
    for (dst, wanted) in idx.into_iter().enumerate() {
        if dst == wanted {
            continue;
        }
        // The wanted element may have been moved by earlier swaps; follow the
        // recorded moves until reaching a slot that has not been processed.
        let mut src = wanted;
        while let Some(&next) = relocated.get(&src) {
            src = next;
        }
        v.swap(dst, src);
        relocated.insert(dst, src);
    }
    v
}
|
|
|
|
/// Reorders the nodes of `g` so that their ids appear in the same order as in `idx`.
///
/// The sort is stable: nodes whose ids map to the same position keep their
/// relative order.
///
/// # Panics
///
/// Panics if the id of some node in `g` does not occur in `idx`.
fn reorder_by_key<T: Eq + Hash + Immutable, U>(mut g: Graph<T, U>, idx: Vec<T>) -> Graph<T, U> {
    // Looking up a node's position is a linear scan of `idx`; the cached-key
    // sort performs that scan once per node instead of once per comparison.
    g.sort_by_cached_key(|node| {
        idx.iter()
            .position(|k| k == &node.id)
            .expect("every node id must be present in `idx`")
    });
    g
}
|
|
|
|
fn dfs<T: Eq + Hash + Clone + Debug + Immutable, U: Debug>(
|
|
g: &Graph<T, U>,
|
|
v: T,
|
|
used: &mut Set<T>,
|
|
idx: &mut Vec<T>,
|
|
) -> Result<(), TopoSortError> {
|
|
used.insert(v.clone());
|
|
let Some(vertex) = g.iter().find(|n| n.id == v) else {
|
|
return Err(TopoSortError::key_not_found(format!("{g:?}: {v:?}")));
|
|
};
|
|
for node_id in vertex.depends_on.iter() {
|
|
// detecting cycles
|
|
if used.contains(node_id) && !idx.contains(node_id) {
|
|
return Err(TopoSortError::cycle_detected(format!(
|
|
"{v:?} -> {node_id:?}"
|
|
)));
|
|
}
|
|
if !used.contains(node_id) {
|
|
dfs(g, node_id.clone(), used, idx)?;
|
|
}
|
|
}
|
|
idx.push(v);
|
|
Ok(())
|
|
}
|
|
|
|
/// perform topological sort on a graph
|
|
#[allow(clippy::result_unit_err)]
|
|
pub fn tsort<T: Eq + Hash + Clone + Debug + Immutable, U: Debug>(
|
|
g: Graph<T, U>,
|
|
) -> Result<Graph<T, U>, TopoSortError> {
|
|
let n = g.len();
|
|
let mut idx = Vec::with_capacity(n);
|
|
let mut used = Set::new();
|
|
for v in g.iter() {
|
|
if !used.contains(&v.id) {
|
|
dfs(&g, v.id.clone(), &mut used, &mut idx)?;
|
|
}
|
|
}
|
|
Ok(reorder_by_key(g, idx))
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use crate::set;

    /// Slot `i` of the result holds the element that started at `idx[i]`.
    #[test]
    fn test_reorder_by_idx() {
        let v = vec!["e", "d", "b", "a", "c"];
        let idx = vec![3, 2, 4, 1, 0];
        assert_eq!(vec!["a", "b", "c", "d", "e"], _reorder_by_idx(v, idx));
    }

    /// A definition-level cycle must be rejected as a cyclic reference.
    #[test]
    fn test_tsort_cyclic() {
        // This is invalid because a cyclic reference exists:
        // ```
        // odd 0 = False
        // odd n = even n - 1
        // even 0 = True
        // even n = odd n - 1
        // ```
        let even = Node::new("even n", (), set!["odd n", "True"]);
        let odd = Node::new("odd n", (), set!["even n", "False"]);
        let tru = Node::new("True", (), set![]);
        let fls = Node::new("False", (), set![]);
        let dag = vec![even, odd, tru, fls];
        let err = tsort(dag).unwrap_err();
        // Every referenced id exists, so the only acceptable failure is the cycle.
        assert_eq!(err.kind, TopoSortErrorKind::CyclicReference);
    }

    /// Referring to a declaration instead of the definition breaks the cycle.
    #[test]
    fn test_tsort() -> Result<(), TopoSortError> {
        // This is valid because a declaration exists:
        // ```
        // odd: Nat -> Bool
        // odd 0 = False
        // odd n = even n - 1
        // even 0 = True
        // even n = odd n - 1 # this refers to the declaration, not the definition
        // ```
        let even = Node::new("even n", (), set!["odd n (decl)", "True"]);
        let odd = Node::new("odd n", (), set!["even n", "False"]);
        let odd_decl = Node::new("odd n (decl)", (), set![]);
        let tru = Node::new("True", (), set![]);
        let fls = Node::new("False", (), set![]);
        let dag = vec![
            even,
            odd.clone(),
            odd_decl.clone(),
            fls.clone(),
            tru.clone(),
        ];
        let sorted = tsort(dag)?;
        assert_eq!(sorted.len(), 5);
        assert!(sorted[0] == odd_decl || sorted[0] == fls || sorted[0] == tru);
        assert_eq!(sorted[4], odd);
        // Every node must come after all of the nodes it depends on.
        for (pos, node) in sorted.iter().enumerate() {
            for dep in node.depends_on.iter() {
                assert!(sorted[..pos].iter().any(|n| &n.id == dep));
            }
        }
        Ok(())
    }
}
|