From f42f58e4c6236aed99b416b93a8f2e5c2901027b Mon Sep 17 00:00:00 2001 From: Richard Feldman Date: Wed, 28 Aug 2019 01:20:55 -0400 Subject: [PATCH] Re-inline ena --- Cargo.lock | 16 +- Cargo.toml | 3 +- src/ena/bitvec.rs | 301 +++++++++++++++++++ src/ena/mod.rs | 15 + src/ena/snapshot_vec.rs | 374 ++++++++++++++++++++++++ src/ena/unify/backing_vec.rs | 220 ++++++++++++++ src/ena/unify/mod.rs | 547 +++++++++++++++++++++++++++++++++++ src/ena/unify/tests.rs | 476 ++++++++++++++++++++++++++++++ src/lib.rs | 3 +- 9 files changed, 1939 insertions(+), 16 deletions(-) create mode 100644 src/ena/bitvec.rs create mode 100644 src/ena/mod.rs create mode 100644 src/ena/snapshot_vec.rs create mode 100644 src/ena/unify/backing_vec.rs create mode 100644 src/ena/unify/mod.rs create mode 100644 src/ena/unify/tests.rs diff --git a/Cargo.lock b/Cargo.lock index b52d52e3b4..0f7f62fea9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -55,14 +55,6 @@ name = "either" version = "1.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -[[package]] -name = "ena" -version = "0.13.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -dependencies = [ - "log 0.4.6 (registry+https://github.com/rust-lang/crates.io-index)", -] - [[package]] name = "fixedbitset" version = "0.1.9" @@ -123,7 +115,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" [[package]] name = "log" -version = "0.4.6" +version = "0.4.8" source = "registry+https://github.com/rust-lang/crates.io-index" dependencies = [ "cfg-if 0.1.6 (registry+https://github.com/rust-lang/crates.io-index)", @@ -263,12 +255,11 @@ version = "0.1.0" dependencies = [ "combine 3.8.1 (registry+https://github.com/rust-lang/crates.io-index)", "dogged 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)", - "ena 0.13.0 (registry+https://github.com/rust-lang/crates.io-index)", "fraction 0.6.2 (registry+https://github.com/rust-lang/crates.io-index)", "fxhash 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)", "im-rc 13.0.0 (registry+https://github.com/rust-lang/crates.io-index)", "indoc 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)", - "log 0.4.6 (registry+https://github.com/rust-lang/crates.io-index)", + "log 0.4.8 (registry+https://github.com/rust-lang/crates.io-index)", "maplit 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)", "num 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)", "petgraph 0.4.13 (registry+https://github.com/rust-lang/crates.io-index)", @@ -371,7 +362,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" "checksum difference 2.0.0 (registry+https://github.com/rust-lang/crates.io-index)" = "524cbf6897b527295dff137cec09ecf3a05f4fddffd7dfcd1585403449e74198" "checksum dogged 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "2638df109789fe360f0d9998c5438dd19a36678aaf845e46f285b688b1a1657a" "checksum either 1.5.2 (registry+https://github.com/rust-lang/crates.io-index)" = "5527cfe0d098f36e3f8839852688e63c8fff1c90b2b405aef730615f9a7bcf7b" -"checksum ena 0.13.0 (registry+https://github.com/rust-lang/crates.io-index)" = "3dc01d68e08ca384955a3aeba9217102ca1aa85b6e168639bf27739f1d749d87" "checksum fixedbitset 0.1.9 (registry+https://github.com/rust-lang/crates.io-index)" = "86d4de0081402f5e88cdac65c8dcdcc73118c1a7a465e2a05f0da05843a8ea33" "checksum fraction 0.6.2 (registry+https://github.com/rust-lang/crates.io-index)" = "1055159ac82fb210c813303f716b6c8db57ace9d5ec2dbbc2e1d7a864c1dd74e" "checksum fxhash 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "c31b6d751ae2c7f11320402d34e41349dd1016f8d5d45e48c4312bc8625af50c" @@ -379,7 +369,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" "checksum indoc 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)" = "a1f59f228c76fda6ecd8dab79683039a7054c748587f682a911094f473647bd6" "checksum indoc-impl 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)" = "63f070ef080db3601c1a0ecc75c7bb35104cc0ce2d7c4e049952a96a61d8933b" "checksum lazy_static 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "bc5729f27f159ddd61f4df6228e827e86643d4d3e7c32183cb30a1c08f604a14" -"checksum log 0.4.6 (registry+https://github.com/rust-lang/crates.io-index)" = "c84ec4b527950aa83a329754b01dbe3f58361d1c5efacd1f6d68c494d08a17c6" +"checksum log 0.4.8 (registry+https://github.com/rust-lang/crates.io-index)" = "14b6052be84e6b71ab17edffc2eeabf5c2c3ae1fdb464aae35ac50c67a44e1f7" "checksum maplit 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)" = "08cbb6b4fef96b6d77bfc40ec491b1690c779e77b05cd9f07f787ed376fd4c43" "checksum memchr 2.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "88579771288728879b57485cc7d6b07d648c9f0141eb955f8ab7f9d45394468e" "checksum num 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "cf4825417e1e1406b3782a8ce92f4d53f26ec055e3622e1881ca8e9f5f9e08db" diff --git a/Cargo.toml b/Cargo.toml index 6ff78c65cf..3d5131c994 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,14 +5,13 @@ authors = ["Richard Feldman "] [dependencies] dogged = { version = "0.2.0", optional = true } -log = "0.4" +log = "0.4.8" petgraph = { version = "0.4.5", optional = true } combine = "3.8.1" im-rc = "13.0.0" fraction = "0.6.2" num = "0.2.0" fxhash = "0.2.1" -ena = "0.13.0" [dev-dependencies] pretty_assertions = "0.5.1" diff --git a/src/ena/bitvec.rs b/src/ena/bitvec.rs new file mode 100644 index 0000000000..3677c8c5e5 --- /dev/null +++ b/src/ena/bitvec.rs @@ -0,0 +1,301 @@ +// Copyright 2015 The Rust Project Developers. See the COPYRIGHT +// file at the top-level directory of this distribution and at +// http://rust-lang.org/COPYRIGHT. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +/// A very simple BitVector type. +pub struct BitVector { + data: Vec, +} + +impl BitVector { + pub fn new(num_bits: usize) -> BitVector { + let num_words = u64s(num_bits); + BitVector { data: vec![0; num_words] } + } + + pub fn contains(&self, bit: usize) -> bool { + let (word, mask) = word_mask(bit); + (self.data[word] & mask) != 0 + } + + /// Returns true if the bit has changed. + pub fn insert(&mut self, bit: usize) -> bool { + let (word, mask) = word_mask(bit); + let data = &mut self.data[word]; + let value = *data; + let new_value = value | mask; + *data = new_value; + new_value != value + } + + pub fn insert_all(&mut self, all: &BitVector) -> bool { + assert!(self.data.len() == all.data.len()); + let mut changed = false; + for (i, j) in self.data.iter_mut().zip(&all.data) { + let value = *i; + *i = value | *j; + if value != *i { + changed = true; + } + } + changed + } + + pub fn grow(&mut self, num_bits: usize) { + let num_words = u64s(num_bits); + let extra_words = self.data.len() - num_words; + self.data.extend((0..extra_words).map(|_| 0)); + } + + /// Iterates over indexes of set bits in a sorted order + pub fn iter<'a>(&'a self) -> BitVectorIter<'a> { + BitVectorIter { + iter: self.data.iter(), + current: 0, + idx: 0, + } + } +} + +pub struct BitVectorIter<'a> { + iter: ::std::slice::Iter<'a, u64>, + current: u64, + idx: usize, +} + +impl<'a> Iterator for BitVectorIter<'a> { + type Item = usize; + fn next(&mut self) -> Option { + while self.current == 0 { + self.current = if let Some(&i) = self.iter.next() { + if i == 0 { + self.idx += 64; + continue; + } else { + self.idx = u64s(self.idx) * 64; + i + } + } else { + return None; + } + } + let offset = self.current.trailing_zeros() as usize; + self.current >>= offset; + self.current >>= 1; // shift otherwise overflows for 0b1000_0000_…_0000 + self.idx += offset + 1; + return Some(self.idx - 1); + } +} + +/// A "bit matrix" is basically a square matrix of booleans +/// represented as one gigantic bitvector. In other words, it is as if +/// you have N bitvectors, each of length N. Note that `elements` here is `N`/ +#[derive(Clone)] +pub struct BitMatrix { + elements: usize, + vector: Vec, +} + +impl BitMatrix { + // Create a new `elements x elements` matrix, initially empty. + pub fn new(elements: usize) -> BitMatrix { + // For every element, we need one bit for every other + // element. Round up to an even number of u64s. + let u64s_per_elem = u64s(elements); + BitMatrix { + elements: elements, + vector: vec![0; elements * u64s_per_elem], + } + } + + /// The range of bits for a given element. + fn range(&self, element: usize) -> (usize, usize) { + let u64s_per_elem = u64s(self.elements); + let start = element * u64s_per_elem; + (start, start + u64s_per_elem) + } + + pub fn add(&mut self, source: usize, target: usize) -> bool { + let (start, _) = self.range(source); + let (word, mask) = word_mask(target); + let mut vector = &mut self.vector[..]; + let v1 = vector[start + word]; + let v2 = v1 | mask; + vector[start + word] = v2; + v1 != v2 + } + + /// Do the bits from `source` contain `target`? + /// + /// Put another way, if the matrix represents (transitive) + /// reachability, can `source` reach `target`? + pub fn contains(&self, source: usize, target: usize) -> bool { + let (start, _) = self.range(source); + let (word, mask) = word_mask(target); + (self.vector[start + word] & mask) != 0 + } + + /// Returns those indices that are reachable from both `a` and + /// `b`. This is an O(n) operation where `n` is the number of + /// elements (somewhat independent from the actual size of the + /// intersection, in particular). + pub fn intersection(&self, a: usize, b: usize) -> Vec { + let (a_start, a_end) = self.range(a); + let (b_start, b_end) = self.range(b); + let mut result = Vec::with_capacity(self.elements); + for (base, (i, j)) in (a_start..a_end).zip(b_start..b_end).enumerate() { + let mut v = self.vector[i] & self.vector[j]; + for bit in 0..64 { + if v == 0 { + break; + } + if v & 0x1 != 0 { + result.push(base * 64 + bit); + } + v >>= 1; + } + } + result + } + + /// Add the bits from `read` to the bits from `write`, + /// return true if anything changed. + /// + /// This is used when computing transitive reachability because if + /// you have an edge `write -> read`, because in that case + /// `write` can reach everything that `read` can (and + /// potentially more). + pub fn merge(&mut self, read: usize, write: usize) -> bool { + let (read_start, read_end) = self.range(read); + let (write_start, write_end) = self.range(write); + let vector = &mut self.vector[..]; + let mut changed = false; + for (read_index, write_index) in (read_start..read_end).zip(write_start..write_end) { + let v1 = vector[write_index]; + let v2 = v1 | vector[read_index]; + vector[write_index] = v2; + changed = changed | (v1 != v2); + } + changed + } +} + +fn u64s(elements: usize) -> usize { + (elements + 63) / 64 +} + +fn word_mask(index: usize) -> (usize, u64) { + let word = index / 64; + let mask = 1 << (index % 64); + (word, mask) +} + +#[test] +fn bitvec_iter_works() { + let mut bitvec = BitVector::new(100); + bitvec.insert(1); + bitvec.insert(10); + bitvec.insert(19); + bitvec.insert(62); + bitvec.insert(63); + bitvec.insert(64); + bitvec.insert(65); + bitvec.insert(66); + bitvec.insert(99); + assert_eq!(bitvec.iter().collect::>(), + [1, 10, 19, 62, 63, 64, 65, 66, 99]); +} + +#[test] +fn bitvec_iter_works_2() { + let mut bitvec = BitVector::new(300); + bitvec.insert(1); + bitvec.insert(10); + bitvec.insert(19); + bitvec.insert(62); + bitvec.insert(66); + bitvec.insert(99); + bitvec.insert(299); + assert_eq!(bitvec.iter().collect::>(), + [1, 10, 19, 62, 66, 99, 299]); + +} + +#[test] +fn bitvec_iter_works_3() { + let mut bitvec = BitVector::new(319); + bitvec.insert(0); + bitvec.insert(127); + bitvec.insert(191); + bitvec.insert(255); + bitvec.insert(319); + assert_eq!(bitvec.iter().collect::>(), [0, 127, 191, 255, 319]); +} + +#[test] +fn union_two_vecs() { + let mut vec1 = BitVector::new(65); + let mut vec2 = BitVector::new(65); + assert!(vec1.insert(3)); + assert!(!vec1.insert(3)); + assert!(vec2.insert(5)); + assert!(vec2.insert(64)); + assert!(vec1.insert_all(&vec2)); + assert!(!vec1.insert_all(&vec2)); + assert!(vec1.contains(3)); + assert!(!vec1.contains(4)); + assert!(vec1.contains(5)); + assert!(!vec1.contains(63)); + assert!(vec1.contains(64)); +} + +#[test] +fn grow() { + let mut vec1 = BitVector::new(65); + assert!(vec1.insert(3)); + assert!(!vec1.insert(3)); + assert!(vec1.insert(5)); + assert!(vec1.insert(64)); + vec1.grow(128); + assert!(vec1.contains(3)); + assert!(vec1.contains(5)); + assert!(vec1.contains(64)); + assert!(!vec1.contains(126)); +} + +#[test] +fn matrix_intersection() { + let mut vec1 = BitMatrix::new(200); + + // (*) Elements reachable from both 2 and 65. + + vec1.add(2, 3); + vec1.add(2, 6); + vec1.add(2, 10); // (*) + vec1.add(2, 64); // (*) + vec1.add(2, 65); + vec1.add(2, 130); + vec1.add(2, 160); // (*) + + vec1.add(64, 133); + + vec1.add(65, 2); + vec1.add(65, 8); + vec1.add(65, 10); // (*) + vec1.add(65, 64); // (*) + vec1.add(65, 68); + vec1.add(65, 133); + vec1.add(65, 160); // (*) + + let intersection = vec1.intersection(2, 64); + assert!(intersection.is_empty()); + + let intersection = vec1.intersection(2, 65); + assert_eq!(intersection, &[10, 64, 160]); +} diff --git a/src/ena/mod.rs b/src/ena/mod.rs new file mode 100644 index 0000000000..36b9093a0c --- /dev/null +++ b/src/ena/mod.rs @@ -0,0 +1,15 @@ +// Copyright 2015 The Rust Project Developers. See the COPYRIGHT +// file at the top-level directory of this distribution and at +// http://rust-lang.org/COPYRIGHT. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! An implementation of union-find. See the `unify` module for more +//! details. + +pub mod snapshot_vec; +pub mod unify; diff --git a/src/ena/snapshot_vec.rs b/src/ena/snapshot_vec.rs new file mode 100644 index 0000000000..7bb68a2e56 --- /dev/null +++ b/src/ena/snapshot_vec.rs @@ -0,0 +1,374 @@ +// Copyright 2014 The Rust Project Developers. See the COPYRIGHT +// file at the top-level directory of this distribution and at +// http://rust-lang.org/COPYRIGHT. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! A utility class for implementing "snapshottable" things; a snapshottable data structure permits +//! you to take a snapshot (via `start_snapshot`) and then, after making some changes, elect either +//! to rollback to the start of the snapshot or commit those changes. +//! +//! This vector is intended to be used as part of an abstraction, not serve as a complete +//! abstraction on its own. As such, while it will roll back most changes on its own, it also +//! supports a `get_mut` operation that gives you an arbitrary mutable pointer into the vector. To +//! ensure that any changes you make this with this pointer are rolled back, you must invoke +//! `record` to record any changes you make and also supplying a delegate capable of reversing +//! those changes. + +use self::UndoLog::*; + +use std::fmt; +use std::mem; +use std::ops; + +#[derive(Debug)] +pub enum UndoLog { + /// New variable with given index was created. + NewElem(usize), + + /// Variable with given index was changed *from* the given value. + SetElem(usize, D::Value), + + /// Extensible set of actions + Other(D::Undo), +} + +pub struct SnapshotVec { + values: Vec, + undo_log: Vec>, + num_open_snapshots: usize, +} + +impl fmt::Debug for SnapshotVec + where D: SnapshotVecDelegate, + D: fmt::Debug, + D::Undo: fmt::Debug, + D::Value: fmt::Debug +{ + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + fmt.debug_struct("SnapshotVec") + .field("values", &self.values) + .field("undo_log", &self.undo_log) + .field("num_open_snapshots", &self.num_open_snapshots) + .finish() + } +} + +// Snapshots are tokens that should be created/consumed linearly. +pub struct Snapshot { + // Number of values at the time the snapshot was taken. + pub(crate) value_count: usize, + // Length of the undo log at the time the snapshot was taken. + undo_len: usize, +} + +pub trait SnapshotVecDelegate { + type Value; + type Undo; + + fn reverse(values: &mut Vec, action: Self::Undo); +} + +// HACK(eddyb) manual impl avoids `Default` bound on `D`. +impl Default for SnapshotVec { + fn default() -> Self { + SnapshotVec { + values: Vec::new(), + undo_log: Vec::new(), + num_open_snapshots: 0, + } + } +} + +impl SnapshotVec { + pub fn new() -> Self { + Self::default() + } + + pub fn with_capacity(c: usize) -> SnapshotVec { + SnapshotVec { + values: Vec::with_capacity(c), + undo_log: Vec::new(), + num_open_snapshots: 0, + } + } + + fn in_snapshot(&self) -> bool { + self.num_open_snapshots > 0 + } + + pub fn record(&mut self, action: D::Undo) { + if self.in_snapshot() { + self.undo_log.push(Other(action)); + } + } + + pub fn len(&self) -> usize { + self.values.len() + } + + pub fn push(&mut self, elem: D::Value) -> usize { + let len = self.values.len(); + self.values.push(elem); + + if self.in_snapshot() { + self.undo_log.push(NewElem(len)); + } + + len + } + + pub fn get(&self, index: usize) -> &D::Value { + &self.values[index] + } + + /// Reserve space for new values, just like an ordinary vec. + pub fn reserve(&mut self, additional: usize) { + // This is not affected by snapshots or anything. + self.values.reserve(additional); + } + + /// Returns a mutable pointer into the vec; whatever changes you make here cannot be undone + /// automatically, so you should be sure call `record()` with some sort of suitable undo + /// action. + pub fn get_mut(&mut self, index: usize) -> &mut D::Value { + &mut self.values[index] + } + + /// Updates the element at the given index. The old value will saved (and perhaps restored) if + /// a snapshot is active. + pub fn set(&mut self, index: usize, new_elem: D::Value) { + let old_elem = mem::replace(&mut self.values[index], new_elem); + if self.in_snapshot() { + self.undo_log.push(SetElem(index, old_elem)); + } + } + + /// Updates all elements. Potentially more efficient -- but + /// otherwise equivalent to -- invoking `set` for each element. + pub fn set_all(&mut self, mut new_elems: impl FnMut(usize) -> D::Value) { + if !self.in_snapshot() { + for (index, slot) in self.values.iter_mut().enumerate() { + *slot = new_elems(index); + } + } else { + for i in 0..self.values.len() { + self.set(i, new_elems(i)); + } + } + } + + pub fn update(&mut self, index: usize, op: OP) + where + OP: FnOnce(&mut D::Value), + D::Value: Clone, + { + if self.in_snapshot() { + let old_elem = self.values[index].clone(); + self.undo_log.push(SetElem(index, old_elem)); + } + op(&mut self.values[index]); + } + + pub fn start_snapshot(&mut self) -> Snapshot { + self.num_open_snapshots += 1; + Snapshot { + value_count: self.values.len(), + undo_len: self.undo_log.len(), + } + } + + pub fn actions_since_snapshot(&self, snapshot: &Snapshot) -> &[UndoLog] { + &self.undo_log[snapshot.undo_len..] + } + + fn assert_open_snapshot(&self, snapshot: &Snapshot) { + // Failures here may indicate a failure to follow a stack discipline. + assert!(self.undo_log.len() >= snapshot.undo_len); + assert!(self.num_open_snapshots > 0); + } + + pub fn rollback_to(&mut self, snapshot: Snapshot) { + debug!("rollback_to({})", snapshot.undo_len); + + self.assert_open_snapshot(&snapshot); + + while self.undo_log.len() > snapshot.undo_len { + match self.undo_log.pop().unwrap() { + NewElem(i) => { + self.values.pop(); + assert!(self.values.len() == i); + } + + SetElem(i, v) => { + self.values[i] = v; + } + + Other(u) => { + D::reverse(&mut self.values, u); + } + } + } + + self.num_open_snapshots -= 1; + } + + /// Commits all changes since the last snapshot. Of course, they + /// can still be undone if there is a snapshot further out. + pub fn commit(&mut self, snapshot: Snapshot) { + debug!("commit({})", snapshot.undo_len); + + self.assert_open_snapshot(&snapshot); + + if self.num_open_snapshots == 1 { + // The root snapshot. It's safe to clear the undo log because + // there's no snapshot further out that we might need to roll back + // to. + assert!(snapshot.undo_len == 0); + self.undo_log.clear(); + } + + self.num_open_snapshots -= 1; + } +} + +impl ops::Deref for SnapshotVec { + type Target = [D::Value]; + fn deref(&self) -> &[D::Value] { + &*self.values + } +} + +impl ops::DerefMut for SnapshotVec { + fn deref_mut(&mut self) -> &mut [D::Value] { + &mut *self.values + } +} + +impl ops::Index for SnapshotVec { + type Output = D::Value; + fn index(&self, index: usize) -> &D::Value { + self.get(index) + } +} + +impl ops::IndexMut for SnapshotVec { + fn index_mut(&mut self, index: usize) -> &mut D::Value { + self.get_mut(index) + } +} + +impl Extend for SnapshotVec { + fn extend(&mut self, iterable: T) + where + T: IntoIterator, + { + let initial_len = self.values.len(); + self.values.extend(iterable); + let final_len = self.values.len(); + + if self.in_snapshot() { + self.undo_log.extend((initial_len..final_len).map(|len| NewElem(len))); + } + } +} + +impl Clone for SnapshotVec +where + D::Value: Clone, + D::Undo: Clone, +{ + fn clone(&self) -> Self { + SnapshotVec { + values: self.values.clone(), + undo_log: self.undo_log.clone(), + num_open_snapshots: self.num_open_snapshots, + } + } +} + +impl Clone for UndoLog +where + D::Value: Clone, + D::Undo: Clone, +{ + fn clone(&self) -> Self { + match *self { + NewElem(i) => NewElem(i), + SetElem(i, ref v) => SetElem(i, v.clone()), + Other(ref u) => Other(u.clone()), + } + } +} + +impl SnapshotVecDelegate for i32 { + type Value = i32; + type Undo = (); + + fn reverse(_: &mut Vec, _: ()) {} +} + +#[test] +fn basic() { + let mut vec: SnapshotVec = SnapshotVec::default(); + assert!(!vec.in_snapshot()); + assert_eq!(vec.len(), 0); + vec.push(22); + vec.push(33); + assert_eq!(vec.len(), 2); + assert_eq!(*vec.get(0), 22); + assert_eq!(*vec.get(1), 33); + vec.set(1, 34); + assert_eq!(vec.len(), 2); + assert_eq!(*vec.get(0), 22); + assert_eq!(*vec.get(1), 34); + + let snapshot = vec.start_snapshot(); + assert!(vec.in_snapshot()); + + vec.push(44); + vec.push(55); + vec.set(1, 35); + assert_eq!(vec.len(), 4); + assert_eq!(*vec.get(0), 22); + assert_eq!(*vec.get(1), 35); + assert_eq!(*vec.get(2), 44); + assert_eq!(*vec.get(3), 55); + + vec.rollback_to(snapshot); + assert!(!vec.in_snapshot()); + + assert_eq!(vec.len(), 2); + assert_eq!(*vec.get(0), 22); + assert_eq!(*vec.get(1), 34); +} + +#[test] +#[should_panic] +fn out_of_order() { + let mut vec: SnapshotVec = SnapshotVec::default(); + vec.push(22); + let snapshot1 = vec.start_snapshot(); + vec.push(33); + let snapshot2 = vec.start_snapshot(); + vec.push(44); + vec.rollback_to(snapshot1); // bogus, but accepted + vec.rollback_to(snapshot2); // asserts +} + +#[test] +fn nested_commit_then_rollback() { + let mut vec: SnapshotVec = SnapshotVec::default(); + vec.push(22); + let snapshot1 = vec.start_snapshot(); + let snapshot2 = vec.start_snapshot(); + vec.set(0, 23); + vec.commit(snapshot2); + assert_eq!(*vec.get(0), 23); + vec.rollback_to(snapshot1); + assert_eq!(*vec.get(0), 22); +} diff --git a/src/ena/unify/backing_vec.rs b/src/ena/unify/backing_vec.rs new file mode 100644 index 0000000000..640d0b0068 --- /dev/null +++ b/src/ena/unify/backing_vec.rs @@ -0,0 +1,220 @@ +#[cfg(feature = "persistent")] +use dogged::DVec; +use ena::snapshot_vec as sv; +use std::ops::{self, Range}; +use std::marker::PhantomData; + +use super::{VarValue, UnifyKey, UnifyValue}; + +#[allow(dead_code)] // rustc BUG +#[allow(type_alias_bounds)] +type Key = ::Key; + +/// Largely internal trait implemented by the unification table +/// backing store types. The most common such type is `InPlace`, +/// which indicates a standard, mutable unification table. +pub trait UnificationStore: + ops::Index>> + Clone + Default +{ + type Key: UnifyKey; + type Value: UnifyValue; + type Snapshot; + + fn start_snapshot(&mut self) -> Self::Snapshot; + + fn rollback_to(&mut self, snapshot: Self::Snapshot); + + fn commit(&mut self, snapshot: Self::Snapshot); + + fn values_since_snapshot(&self, snapshot: &Self::Snapshot) -> Range; + + fn reset_unifications( + &mut self, + value: impl FnMut(u32) -> VarValue, + ); + + fn len(&self) -> usize; + + fn push(&mut self, value: VarValue); + + fn reserve(&mut self, num_new_values: usize); + + fn update(&mut self, index: usize, op: F) + where F: FnOnce(&mut VarValue); + + fn tag() -> &'static str { + Self::Key::tag() + } +} + +/// Backing store for an in-place unification table. +/// Not typically used directly. +#[derive(Clone, Debug)] +pub struct InPlace { + values: sv::SnapshotVec> +} + +// HACK(eddyb) manual impl avoids `Default` bound on `K`. +impl Default for InPlace { + fn default() -> Self { + InPlace { values: sv::SnapshotVec::new() } + } +} + +impl UnificationStore for InPlace { + type Key = K; + type Value = K::Value; + type Snapshot = sv::Snapshot; + + #[inline] + fn start_snapshot(&mut self) -> Self::Snapshot { + self.values.start_snapshot() + } + + #[inline] + fn rollback_to(&mut self, snapshot: Self::Snapshot) { + self.values.rollback_to(snapshot); + } + + #[inline] + fn commit(&mut self, snapshot: Self::Snapshot) { + self.values.commit(snapshot); + } + + #[inline] + fn values_since_snapshot(&self, snapshot: &Self::Snapshot) -> Range { + snapshot.value_count..self.len() + } + + #[inline] + fn reset_unifications( + &mut self, + mut value: impl FnMut(u32) -> VarValue, + ) { + self.values.set_all(|i| value(i as u32)); + } + + fn len(&self) -> usize { + self.values.len() + } + + #[inline] + fn push(&mut self, value: VarValue) { + self.values.push(value); + } + + #[inline] + fn reserve(&mut self, num_new_values: usize) { + self.values.reserve(num_new_values); + } + + #[inline] + fn update(&mut self, index: usize, op: F) + where F: FnOnce(&mut VarValue) + { + self.values.update(index, op) + } +} + +impl ops::Index for InPlace + where K: UnifyKey +{ + type Output = VarValue; + fn index(&self, index: usize) -> &VarValue { + &self.values[index] + } +} + +#[derive(Copy, Clone, Debug)] +struct Delegate(PhantomData); + +impl sv::SnapshotVecDelegate for Delegate { + type Value = VarValue; + type Undo = (); + + fn reverse(_: &mut Vec>, _: ()) {} +} + +#[cfg(feature = "persistent")] +#[derive(Clone, Debug)] +pub struct Persistent { + values: DVec> +} + +// HACK(eddyb) manual impl avoids `Default` bound on `K`. +#[cfg(feature = "persistent")] +impl Default for Persistent { + fn default() -> Self { + Persistent { values: DVec::new() } + } +} + +#[cfg(feature = "persistent")] +impl UnificationStore for Persistent { + type Key = K; + type Value = K::Value; + type Snapshot = Self; + + #[inline] + fn start_snapshot(&mut self) -> Self::Snapshot { + self.clone() + } + + #[inline] + fn rollback_to(&mut self, snapshot: Self::Snapshot) { + *self = snapshot; + } + + #[inline] + fn commit(&mut self, _snapshot: Self::Snapshot) {} + + #[inline] + fn values_since_snapshot(&self, snapshot: &Self::Snapshot) -> Range { + snapshot.len()..self.len() + } + + #[inline] + fn reset_unifications( + &mut self, + mut value: impl FnMut(u32) -> VarValue, + ) { + // Without extending dogged, there isn't obviously a more + // efficient way to do this. But it's pretty dumb. Maybe + // dogged needs a `map`. + for i in 0 .. self.values.len() { + self.values[i] = value(i as u32); + } + } + + fn len(&self) -> usize { + self.values.len() + } + + #[inline] + fn push(&mut self, value: VarValue) { + self.values.push(value); + } + + #[inline] + fn reserve(&mut self, _num_new_values: usize) { + // not obviously relevant to DVec. + } + + #[inline] + fn update(&mut self, index: usize, op: F) + where F: FnOnce(&mut VarValue) + { + let p = &mut self.values[index]; + op(p); + } +} + +#[cfg(feature = "persistent")] +impl ops::Index for Persistent + where K: UnifyKey +{ + type Output = VarValue; + fn index(&self, index: usize) -> &VarValue { + &self.values[index] + } +} diff --git a/src/ena/unify/mod.rs b/src/ena/unify/mod.rs new file mode 100644 index 0000000000..dbef13a5a8 --- /dev/null +++ b/src/ena/unify/mod.rs @@ -0,0 +1,547 @@ +// Copyright 2012-2014 The Rust Project Developers. See the COPYRIGHT +// file at the top-level directory of this distribution and at +// http://rust-lang.org/COPYRIGHT. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! Union-find implementation. The main type is `UnificationTable`. +//! +//! You can define your own type for the *keys* in the table, but you +//! must implement `UnifyKey` for that type. The assumption is that +//! keys will be newtyped integers, hence we require that they +//! implement `Copy`. +//! +//! Keys can have values associated with them. The assumption is that +//! these values are cheaply cloneable (ideally, `Copy`), and some of +//! the interfaces are oriented around that assumption. If you just +//! want the classical "union-find" algorithm where you group things +//! into sets, use the `Value` type of `()`. +//! +//! When you have keys with non-trivial values, you must also define +//! how those values can be merged. As part of doing this, you can +//! define the "error" type to return on error; if errors are not +//! possible, use `NoError` (an uninstantiable struct). Using this +//! type also unlocks various more ergonomic methods (e.g., `union()` +//! in place of `unify_var_var()`). +//! +//! The best way to see how it is used is to read the `tests.rs` file; +//! search for e.g. `UnitKey`. + +use std::marker; +use std::fmt::Debug; +use std::ops::Range; + +mod backing_vec; +pub use self::backing_vec::{InPlace, UnificationStore}; + +#[cfg(feature = "persistent")] +pub use self::backing_vec::Persistent; + + +#[cfg(test)] +mod tests; + +/// This trait is implemented by any type that can serve as a type +/// variable. We call such variables *unification keys*. For example, +/// this trait is implemented by `IntVid`, which represents integral +/// variables. +/// +/// Each key type has an associated value type `V`. For example, for +/// `IntVid`, this is `Option`, representing some +/// (possibly not yet known) sort of integer. +/// +/// Clients are expected to provide implementations of this trait; you +/// can see some examples in the `test` module. +pub trait UnifyKey: Copy + Clone + Debug + PartialEq { + type Value: UnifyValue; + + fn index(&self) -> u32; + + fn from_index(u: u32) -> Self; + + fn tag() -> &'static str; + + /// If true, then `self` should be preferred as root to `other`. + /// Note that we assume a consistent partial ordering, so + /// returning true implies that `other.prefer_as_root_to(self)` + /// would return false. If there is no ordering between two keys + /// (i.e., `a.prefer_as_root_to(b)` and `b.prefer_as_root_to(a)` + /// both return false) then the rank will be used to determine the + /// root in an optimal way. + /// + /// NB. The only reason to implement this method is if you want to + /// control what value is returned from `find()`. In general, it + /// is better to let the unification table determine the root, + /// since overriding the rank can cause execution time to increase + /// dramatically. + #[allow(unused_variables)] + fn order_roots( + a: Self, + a_value: &Self::Value, + b: Self, + b_value: &Self::Value, + ) -> Option<(Self, Self)> { + None + } +} + +/// Trait implemented for **values** associated with a unification +/// key. This trait defines how to merge the values from two keys that +/// are unioned together. This merging can be fallible. If you attempt +/// to union two keys whose values cannot be merged, then the error is +/// propagated up and the two keys are not unioned. +/// +/// This crate provides implementations of `UnifyValue` for `()` +/// (which is infallible) and `Option` (where `T: UnifyValue`). The +/// option implementation merges two sum-values using the `UnifyValue` +/// implementation of `T`. +/// +/// See also `EqUnifyValue`, which is a convenience trait for cases +/// where the "merge" operation succeeds only if the two values are +/// equal. +pub trait UnifyValue: Clone + Debug { + /// Defines the type to return when merging of two values fails. + /// If merging is infallible, use the special struct `NoError` + /// found in this crate, which unlocks various more convenient + /// methods on the unification table. + type Error; + + /// Given two values, produce a new value that combines them. + /// If that is not possible, produce an error. + fn unify_values(value1: &Self, value2: &Self) -> Result; +} + +/// A convenient helper for unification values which must be equal or +/// else an error occurs. For example, if you are unifying types in a +/// simple functional language, this may be appropriate, since (e.g.) +/// you can't unify a type variable bound to `int` with one bound to +/// `float` (but you can unify two type variables both bound to +/// `int`). +/// +/// Any type which implements `EqUnifyValue` automatially implements +/// `UnifyValue`; if the two values are equal, merging is permitted. +/// Otherwise, the error `(v1, v2)` is returned, where `v1` and `v2` +/// are the two unequal values. +pub trait EqUnifyValue: Eq + Clone + Debug {} + +impl UnifyValue for T { + type Error = (T, T); + + fn unify_values(value1: &Self, value2: &Self) -> Result { + if value1 == value2 { + Ok(value1.clone()) + } else { + Err((value1.clone(), value2.clone())) + } + } +} + +/// A struct which can never be instantiated. Used +/// for the error type for infallible cases. +#[derive(Debug)] +pub struct NoError { + _dummy: (), +} + +/// Value of a unification key. We implement Tarjan's union-find +/// algorithm: when two keys are unified, one of them is converted +/// into a "redirect" pointing at the other. These redirects form a +/// DAG: the roots of the DAG (nodes that are not redirected) are each +/// associated with a value of type `V` and a rank. The rank is used +/// to keep the DAG relatively balanced, which helps keep the running +/// time of the algorithm under control. For more information, see +/// . +#[derive(PartialEq, Clone, Debug)] +pub struct VarValue { // FIXME pub + parent: K, // if equal to self, this is a root + value: K::Value, // value assigned (only relevant to root) + rank: u32, // max depth (only relevant to root) +} + +/// Table of unification keys and their values. You must define a key type K +/// that implements the `UnifyKey` trait. Unification tables can be used in two-modes: +/// +/// - in-place (`UnificationTable>` or `InPlaceUnificationTable`): +/// - This is the standard mutable mode, where the array is modified +/// in place. +/// - To do backtracking, you can employ the `snapshot` and `rollback_to` +/// methods. +/// - persistent (`UnificationTable>` or `PersistentUnificationTable`): +/// - In this mode, we use a persistent vector to store the data, so that +/// cloning the table is an O(1) operation. +/// - This implies that ordinary operations are quite a bit slower though. +/// - Requires the `persistent` feature be selected in your Cargo.toml file. +#[derive(Clone, Debug, Default)] +pub struct UnificationTable { + /// Indicates the current value of each key. + values: S, +} + +/// A unification table that uses an "in-place" vector. +#[allow(type_alias_bounds)] +pub type InPlaceUnificationTable = UnificationTable>; + +/// A unification table that uses a "persistent" vector. +#[cfg(feature = "persistent")] +#[allow(type_alias_bounds)] +pub type PersistentUnificationTable = UnificationTable>; + +/// At any time, users may snapshot a unification table. The changes +/// made during the snapshot may either be *committed* or *rolled back*. +pub struct Snapshot { + // Link snapshot to the unification store `S` of the table. + marker: marker::PhantomData, + snapshot: S::Snapshot, +} + +impl VarValue { + fn new_var(key: K, value: K::Value) -> VarValue { + VarValue::new(key, value, 0) + } + + fn new(parent: K, value: K::Value, rank: u32) -> VarValue { + VarValue { + parent: parent, // this is a root + value: value, + rank: rank, + } + } + + fn redirect(&mut self, to: K) { + self.parent = to; + } + + fn root(&mut self, rank: u32, value: K::Value) { + self.rank = rank; + self.value = value; + } + + fn parent(&self, self_key: K) -> Option { + self.if_not_self(self.parent, self_key) + } + + fn if_not_self(&self, key: K, self_key: K) -> Option { + if key == self_key { + None + } else { + Some(key) + } + } +} + +// We can't use V:LatticeValue, much as I would like to, +// because frequently the pattern is that V=Option for some +// other type parameter U, and we have no way to say +// Option:LatticeValue. + +impl UnificationTable { + pub fn new() -> Self { + Self::default() + } + + /// Starts a new snapshot. Each snapshot must be either + /// rolled back or committed in a "LIFO" (stack) order. + pub fn snapshot(&mut self) -> Snapshot { + Snapshot { + marker: marker::PhantomData::, + snapshot: self.values.start_snapshot(), + } + } + + /// Reverses all changes since the last snapshot. Also + /// removes any keys that have been created since then. + pub fn rollback_to(&mut self, snapshot: Snapshot) { + debug!("{}: rollback_to()", S::tag()); + self.values.rollback_to(snapshot.snapshot); + } + + /// Commits all changes since the last snapshot. Of course, they + /// can still be undone if there is a snapshot further out. + pub fn commit(&mut self, snapshot: Snapshot) { + debug!("{}: commit()", S::tag()); + self.values.commit(snapshot.snapshot); + } + + /// Creates a fresh key with the given value. + pub fn new_key(&mut self, value: S::Value) -> S::Key { + let len = self.values.len(); + let key: S::Key = UnifyKey::from_index(len as u32); + self.values.push(VarValue::new_var(key, value)); + debug!("{}: created new key: {:?}", S::tag(), key); + key + } + + /// Reserve memory for `num_new_keys` to be created. Does not + /// actually create the new keys; you must then invoke `new_key`. + pub fn reserve(&mut self, num_new_keys: usize) { + self.values.reserve(num_new_keys); + } + + /// Clears all unifications that have been performed, resetting to + /// the initial state. The values of each variable are given by + /// the closure. + pub fn reset_unifications( + &mut self, + mut value: impl FnMut(S::Key) -> S::Value, + ) { + self.values.reset_unifications(|i| { + let key = UnifyKey::from_index(i as u32); + let value = value(key); + VarValue::new_var(key, value) + }); + } + + /// Returns the number of keys created so far. + pub fn len(&self) -> usize { + self.values.len() + } + + /// Returns the keys of all variables created since the `snapshot`. + pub fn vars_since_snapshot( + &self, + snapshot: &Snapshot, + ) -> Range { + let range = self.values.values_since_snapshot(&snapshot.snapshot); + S::Key::from_index(range.start as u32)..S::Key::from_index(range.end as u32) + } + + /// Obtains the current value for a particular key. + /// Not for end-users; they can use `probe_value`. + fn value(&self, key: S::Key) -> &VarValue { + &self.values[key.index() as usize] + } + + /// Find the root node for `vid`. This uses the standard + /// union-find algorithm with path compression: + /// . + /// + /// NB. This is a building-block operation and you would probably + /// prefer to call `probe` below. + fn get_root_key(&mut self, vid: S::Key) -> S::Key { + let redirect = { + match self.value(vid).parent(vid) { + None => return vid, + Some(redirect) => redirect, + } + }; + + let root_key: S::Key = self.get_root_key(redirect); + if root_key != redirect { + // Path compression + self.update_value(vid, |value| value.parent = root_key); + } + + root_key + } + + fn update_value(&mut self, key: S::Key, op: OP) + where + OP: FnOnce(&mut VarValue), + { + self.values.update(key.index() as usize, op); + debug!("Updated variable {:?} to {:?}", key, self.value(key)); + } + + /// Either redirects `node_a` to `node_b` or vice versa, depending + /// on the relative rank. The value associated with the new root + /// will be `new_value`. + /// + /// NB: This is the "union" operation of "union-find". It is + /// really more of a building block. If the values associated with + /// your key are non-trivial, you would probably prefer to call + /// `unify_var_var` below. + fn unify_roots(&mut self, key_a: S::Key, key_b: S::Key, new_value: S::Value) { + debug!("unify(key_a={:?}, key_b={:?})", key_a, key_b); + + let rank_a = self.value(key_a).rank; + let rank_b = self.value(key_b).rank; + if let Some((new_root, redirected)) = + S::Key::order_roots( + key_a, + &self.value(key_a).value, + key_b, + &self.value(key_b).value, + ) { + // compute the new rank for the new root that they chose; + // this may not be the optimal choice. + let new_rank = if new_root == key_a { + debug_assert!(redirected == key_b); + if rank_a > rank_b { + rank_a + } else { + rank_b + 1 + } + } else { + debug_assert!(new_root == key_b); + debug_assert!(redirected == key_a); + if rank_b > rank_a { + rank_b + } else { + rank_a + 1 + } + }; + self.redirect_root(new_rank, redirected, new_root, new_value); + } else if rank_a > rank_b { + // a has greater rank, so a should become b's parent, + // i.e., b should redirect to a. + self.redirect_root(rank_a, key_b, key_a, new_value); + } else if rank_a < rank_b { + // b has greater rank, so a should redirect to b. + self.redirect_root(rank_b, key_a, key_b, new_value); + } else { + // If equal, redirect one to the other and increment the + // other's rank. + self.redirect_root(rank_a + 1, key_a, key_b, new_value); + } + } + + /// Internal method to redirect `old_root_key` (which is currently + /// a root) to a child of `new_root_key` (which will remain a + /// root). The rank and value of `new_root_key` will be updated to + /// `new_rank` and `new_value` respectively. + fn redirect_root( + &mut self, + new_rank: u32, + old_root_key: S::Key, + new_root_key: S::Key, + new_value: S::Value, + ) { + self.update_value(old_root_key, |old_root_value| { + old_root_value.redirect(new_root_key); + }); + self.update_value(new_root_key, |new_root_value| { + new_root_value.root(new_rank, new_value); + }); + } +} + +/// //////////////////////////////////////////////////////////////////////// +/// Public API + +impl<'tcx, S, K, V> UnificationTable +where + S: UnificationStore, + K: UnifyKey, + V: UnifyValue, +{ + /// Unions two keys without the possibility of failure; only + /// applicable when unify values use `NoError` as their error + /// type. + pub fn union(&mut self, a_id: K1, b_id: K2) + where + K1: Into, + K2: Into, + V: UnifyValue, + { + self.unify_var_var(a_id, b_id).unwrap(); + } + + /// Unions a key and a value without the possibility of failure; + /// only applicable when unify values use `NoError` as their error + /// type. + pub fn union_value(&mut self, id: K1, value: V) + where + K1: Into, + V: UnifyValue, + { + self.unify_var_value(id, value).unwrap(); + } + + /// Given two keys, indicates whether they have been unioned together. + pub fn unioned(&mut self, a_id: K1, b_id: K2) -> bool + where + K1: Into, + K2: Into, + { + self.find(a_id) == self.find(b_id) + } + + /// Given a key, returns the (current) root key. + pub fn find(&mut self, id: K1) -> K + where + K1: Into, + { + let id = id.into(); + self.get_root_key(id) + } + + /// Unions together two variables, merging their values. If + /// merging the values fails, the error is propagated and this + /// method has no effect. + pub fn unify_var_var(&mut self, a_id: K1, b_id: K2) -> Result<(), V::Error> + where + K1: Into, + K2: Into, + { + let a_id = a_id.into(); + let b_id = b_id.into(); + + let root_a = self.get_root_key(a_id); + let root_b = self.get_root_key(b_id); + + if root_a == root_b { + return Ok(()); + } + + let combined = V::unify_values(&self.value(root_a).value, &self.value(root_b).value)?; + + Ok(self.unify_roots(root_a, root_b, combined)) + } + + /// Sets the value of the key `a_id` to `b`, attempting to merge + /// with the previous value. + pub fn unify_var_value(&mut self, a_id: K1, b: V) -> Result<(), V::Error> + where + K1: Into, + { + let a_id = a_id.into(); + let root_a = self.get_root_key(a_id); + let value = V::unify_values(&self.value(root_a).value, &b)?; + self.update_value(root_a, |node| node.value = value); + Ok(()) + } + + /// Returns the current value for the given key. If the key has + /// been union'd, this will give the value from the current root. + pub fn probe_value(&mut self, id: K1) -> V + where + K1: Into, + { + let id = id.into(); + let id = self.get_root_key(id); + self.value(id).value.clone() + } +} + + +/////////////////////////////////////////////////////////////////////////// + +impl UnifyValue for () { + type Error = NoError; + + fn unify_values(_: &(), _: &()) -> Result<(), NoError> { + Ok(()) + } +} + +impl UnifyValue for Option { + type Error = V::Error; + + fn unify_values(a: &Option, b: &Option) -> Result { + match (a, b) { + (&None, &None) => Ok(None), + (&Some(ref v), &None) | + (&None, &Some(ref v)) => Ok(Some(v.clone())), + (&Some(ref a), &Some(ref b)) => { + match V::unify_values(a, b) { + Ok(v) => Ok(Some(v)), + Err(err) => Err(err), + } + } + } + } +} diff --git a/src/ena/unify/tests.rs b/src/ena/unify/tests.rs new file mode 100644 index 0000000000..8910ad473b --- /dev/null +++ b/src/ena/unify/tests.rs @@ -0,0 +1,476 @@ +// Copyright 2015 The Rust Project Developers. See the COPYRIGHT +// file at the top-level directory of this distribution and at +// http://rust-lang.org/COPYRIGHT. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +// Naming the benchmarks using uppercase letters helps them sort +// better. +#![allow(non_snake_case)] + +#[cfg(feature = "bench")] +extern crate test; +#[cfg(feature = "bench")] +use self::test::Bencher; +use std::cmp; +use ena::unify::{NoError, InPlace, InPlaceUnificationTable, UnifyKey, EqUnifyValue, UnifyValue}; +use ena::unify::{UnificationStore, UnificationTable}; +#[cfg(feature = "persistent")] +use ena::unify::Persistent; + +#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)] +struct UnitKey(u32); + +impl UnifyKey for UnitKey { + type Value = (); + fn index(&self) -> u32 { + self.0 + } + fn from_index(u: u32) -> UnitKey { + UnitKey(u) + } + fn tag() -> &'static str { + "UnitKey" + } +} + +macro_rules! all_modes { + ($name:ident for $t:ty => $body:tt) => { + fn test_body<$name: UnificationStore::Value>>() { + $body + } + + test_body::>(); + + #[cfg(feature = "persistent")] + test_body::>(); + } +} + +#[test] +fn basic() { + all_modes! { + S for UnitKey => { + let mut ut: UnificationTable = UnificationTable::new(); + let k1 = ut.new_key(()); + let k2 = ut.new_key(()); + assert_eq!(ut.unioned(k1, k2), false); + ut.union(k1, k2); + assert_eq!(ut.unioned(k1, k2), true); + } + } +} + +#[test] +fn big_array() { + all_modes! { + S for UnitKey => { + let mut ut: UnificationTable = UnificationTable::new(); + let mut keys = Vec::new(); + const MAX: usize = 1 << 15; + + for _ in 0..MAX { + keys.push(ut.new_key(())); + } + + for i in 1..MAX { + let l = keys[i - 1]; + let r = keys[i]; + ut.union(l, r); + } + + for i in 0..MAX { + assert!(ut.unioned(keys[0], keys[i])); + } + } + } +} + +#[cfg(feature = "bench")] +fn big_array_bench_generic>(b: &mut Bencher) { + let mut ut: UnificationTable = UnificationTable::new(); + let mut keys = Vec::new(); + const MAX: usize = 1 << 15; + + for _ in 0..MAX { + keys.push(ut.new_key(())); + } + + b.iter(|| { + for i in 1..MAX { + let l = keys[i - 1]; + let r = keys[i]; + ut.union(l, r); + } + + for i in 0..MAX { + assert!(ut.unioned(keys[0], keys[i])); + } + }) +} + +#[cfg(feature = "bench")] +#[bench] +fn big_array_bench_InPlace(b: &mut Bencher) { + big_array_bench_generic::>(b); +} + +#[cfg(all(feature = "bench", feature = "persistent"))] +#[bench] +fn big_array_bench_Persistent(b: &mut Bencher) { + big_array_bench_generic::>(b); +} + +#[cfg(feature = "bench")] +fn big_array_bench_in_snapshot_generic>(b: &mut Bencher) { + let mut ut: UnificationTable = UnificationTable::new(); + let mut keys = Vec::new(); + const MAX: usize = 1 << 15; + + for _ in 0..MAX { + keys.push(ut.new_key(())); + } + + b.iter(|| { + let snapshot = ut.snapshot(); + + for i in 1..MAX { + let l = keys[i - 1]; + let r = keys[i]; + ut.union(l, r); + } + + for i in 0..MAX { + assert!(ut.unioned(keys[0], keys[i])); + } + + ut.rollback_to(snapshot); + }) +} + +#[cfg(feature = "bench")] +#[bench] +fn big_array_bench_in_snapshot_InPlace(b: &mut Bencher) { + big_array_bench_in_snapshot_generic::>(b); +} + +#[cfg(all(feature = "bench", feature = "persistent"))] +#[bench] +fn big_array_bench_in_snapshot_Persistent(b: &mut Bencher) { + big_array_bench_in_snapshot_generic::>(b); +} + +#[cfg(feature = "bench")] +fn big_array_bench_clone_generic>(b: &mut Bencher) { + let mut ut: UnificationTable = UnificationTable::new(); + let mut keys = Vec::new(); + const MAX: usize = 1 << 15; + + for _ in 0..MAX { + keys.push(ut.new_key(())); + } + + b.iter(|| { + let saved_table = ut.clone(); + + for i in 1..MAX { + let l = keys[i - 1]; + let r = keys[i]; + ut.union(l, r); + } + + for i in 0..MAX { + assert!(ut.unioned(keys[0], keys[i])); + } + + ut = saved_table; + }) +} + +#[cfg(feature = "bench")] +#[bench] +fn big_array_bench_clone_InPlace(b: &mut Bencher) { + big_array_bench_clone_generic::>(b); +} + +#[cfg(all(feature = "bench", feature = "persistent"))] +#[bench] +fn big_array_bench_clone_Persistent(b: &mut Bencher) { + big_array_bench_clone_generic::>(b); +} + +#[test] +fn even_odd() { + all_modes! { + S for UnitKey => { + let mut ut: UnificationTable = UnificationTable::new(); + let mut keys = Vec::new(); + const MAX: usize = 1 << 10; + + for i in 0..MAX { + let key = ut.new_key(()); + keys.push(key); + + if i >= 2 { + ut.union(key, keys[i - 2]); + } + } + + for i in 1..MAX { + assert!(!ut.unioned(keys[i - 1], keys[i])); + } + + for i in 2..MAX { + assert!(ut.unioned(keys[i - 2], keys[i])); + } + } + } +} + +#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)] +struct IntKey(u32); + +impl UnifyKey for IntKey { + type Value = Option; + fn index(&self) -> u32 { + self.0 + } + fn from_index(u: u32) -> IntKey { + IntKey(u) + } + fn tag() -> &'static str { + "IntKey" + } +} + +impl EqUnifyValue for i32 {} + +#[test] +fn unify_same_int_twice() { + all_modes! { + S for IntKey => { + let mut ut: UnificationTable = UnificationTable::new(); + let k1 = ut.new_key(None); + let k2 = ut.new_key(None); + assert!(ut.unify_var_value(k1, Some(22)).is_ok()); + assert!(ut.unify_var_value(k2, Some(22)).is_ok()); + assert!(ut.unify_var_var(k1, k2).is_ok()); + assert_eq!(ut.probe_value(k1), Some(22)); + } + } +} + +#[test] +fn unify_vars_then_int_indirect() { + all_modes! { + S for IntKey => { + let mut ut: UnificationTable = UnificationTable::new(); + let k1 = ut.new_key(None); + let k2 = ut.new_key(None); + assert!(ut.unify_var_var(k1, k2).is_ok()); + assert!(ut.unify_var_value(k1, Some(22)).is_ok()); + assert_eq!(ut.probe_value(k2), Some(22)); + } + } +} + +#[test] +fn unify_vars_different_ints_1() { + all_modes! { + S for IntKey => { + let mut ut: UnificationTable = UnificationTable::new(); + let k1 = ut.new_key(None); + let k2 = ut.new_key(None); + assert!(ut.unify_var_var(k1, k2).is_ok()); + assert!(ut.unify_var_value(k1, Some(22)).is_ok()); + assert!(ut.unify_var_value(k2, Some(23)).is_err()); + } + } +} + +#[test] +fn unify_vars_different_ints_2() { + all_modes! { + S for IntKey => { + let mut ut: UnificationTable = UnificationTable::new(); + let k1 = ut.new_key(None); + let k2 = ut.new_key(None); + assert!(ut.unify_var_var(k2, k1).is_ok()); + assert!(ut.unify_var_value(k1, Some(22)).is_ok()); + assert!(ut.unify_var_value(k2, Some(23)).is_err()); + } + } +} + +#[test] +fn unify_distinct_ints_then_vars() { + all_modes! { + S for IntKey => { + let mut ut: UnificationTable = UnificationTable::new(); + let k1 = ut.new_key(None); + let k2 = ut.new_key(None); + assert!(ut.unify_var_value(k1, Some(22)).is_ok()); + assert!(ut.unify_var_value(k2, Some(23)).is_ok()); + assert!(ut.unify_var_var(k2, k1).is_err()); + } + } +} + +#[test] +fn unify_root_value_1() { + all_modes! { + S for IntKey => { + let mut ut: UnificationTable = UnificationTable::new(); + let k1 = ut.new_key(None); + let k2 = ut.new_key(None); + let k3 = ut.new_key(None); + assert!(ut.unify_var_value(k1, Some(22)).is_ok()); + assert!(ut.unify_var_var(k1, k2).is_ok()); + assert!(ut.unify_var_value(k3, Some(23)).is_ok()); + assert!(ut.unify_var_var(k1, k3).is_err()); + } + } +} + +#[test] +fn unify_root_value_2() { + all_modes! { + S for IntKey => { + let mut ut: UnificationTable = UnificationTable::new(); + let k1 = ut.new_key(None); + let k2 = ut.new_key(None); + let k3 = ut.new_key(None); + assert!(ut.unify_var_value(k1, Some(22)).is_ok()); + assert!(ut.unify_var_var(k2, k1).is_ok()); + assert!(ut.unify_var_value(k3, Some(23)).is_ok()); + assert!(ut.unify_var_var(k1, k3).is_err()); + } + } +} + +#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)] +struct OrderedKey(u32); + +#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)] +struct OrderedRank(u32); + +impl UnifyKey for OrderedKey { + type Value = OrderedRank; + fn index(&self) -> u32 { + self.0 + } + fn from_index(u: u32) -> OrderedKey { + OrderedKey(u) + } + fn tag() -> &'static str { + "OrderedKey" + } + fn order_roots( + a: OrderedKey, + a_rank: &OrderedRank, + b: OrderedKey, + b_rank: &OrderedRank, + ) -> Option<(OrderedKey, OrderedKey)> { + println!("{:?} vs {:?}", a_rank, b_rank); + if a_rank > b_rank { + Some((a, b)) + } else if b_rank > a_rank { + Some((b, a)) + } else { + None + } + } +} + +impl UnifyValue for OrderedRank { + type Error = NoError; + + fn unify_values(value1: &Self, value2: &Self) -> Result { + Ok(OrderedRank(cmp::max(value1.0, value2.0))) + } +} + +#[test] +fn ordered_key() { + all_modes! { + S for OrderedKey => { + let mut ut: UnificationTable = UnificationTable::new(); + + let k0_1 = ut.new_key(OrderedRank(0)); + let k0_2 = ut.new_key(OrderedRank(0)); + let k0_3 = ut.new_key(OrderedRank(0)); + let k0_4 = ut.new_key(OrderedRank(0)); + + ut.union(k0_1, k0_2); // rank of one of those will now be 1 + ut.union(k0_3, k0_4); // rank of new root also 1 + ut.union(k0_1, k0_3); // rank of new root now 2 + + let k0_5 = ut.new_key(OrderedRank(0)); + let k0_6 = ut.new_key(OrderedRank(0)); + ut.union(k0_5, k0_6); // rank of new root now 1 + + ut.union(k0_1, k0_5); // new root rank 2, should not be k0_5 or k0_6 + assert!(vec![k0_1, k0_2, k0_3, k0_4].contains(&ut.find(k0_1))); + } + } +} + +#[test] +fn ordered_key_k1() { + all_modes! { + S for UnitKey => { + let mut ut: InPlaceUnificationTable = UnificationTable::new(); + + let k0_1 = ut.new_key(OrderedRank(0)); + let k0_2 = ut.new_key(OrderedRank(0)); + let k0_3 = ut.new_key(OrderedRank(0)); + let k0_4 = ut.new_key(OrderedRank(0)); + + ut.union(k0_1, k0_2); // rank of one of those will now be 1 + ut.union(k0_3, k0_4); // rank of new root also 1 + ut.union(k0_1, k0_3); // rank of new root now 2 + + let k1_5 = ut.new_key(OrderedRank(1)); + let k1_6 = ut.new_key(OrderedRank(1)); + ut.union(k1_5, k1_6); // rank of new root now 1 + + ut.union(k0_1, k1_5); // even though k1 has lower rank, it wins + assert!( + vec![k1_5, k1_6].contains(&ut.find(k0_1)), + "unexpected choice for root: {:?}", + ut.find(k0_1) + ); + } + } +} + +/// Test that we *can* clone. +#[test] +fn clone_table() { + all_modes! { + S for IntKey => { + let mut ut: UnificationTable = UnificationTable::new(); + let k1 = ut.new_key(None); + let k2 = ut.new_key(None); + let k3 = ut.new_key(None); + assert!(ut.unify_var_value(k1, Some(22)).is_ok()); + assert!(ut.unify_var_value(k2, Some(22)).is_ok()); + assert!(ut.unify_var_var(k1, k2).is_ok()); + assert_eq!(ut.probe_value(k3), None); + + let mut ut1 = ut.clone(); + assert_eq!(ut1.probe_value(k1), Some(22)); + assert_eq!(ut1.probe_value(k3), None); + + assert!(ut.unify_var_value(k3, Some(44)).is_ok()); + + assert_eq!(ut1.probe_value(k1), Some(22)); + assert_eq!(ut1.probe_value(k3), None); + assert_eq!(ut.probe_value(k3), Some(44)); + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 9994247b2e..c9fe73ebd5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -17,11 +17,12 @@ pub mod solve; pub mod unify; pub mod infer; pub mod pretty_print_types; +pub mod ena; extern crate im_rc; extern crate fraction; extern crate num; extern crate fxhash; -extern crate ena; #[macro_use] extern crate combine; +#[macro_use] extern crate log;