mirror of
https://github.com/GraphiteEditor/Graphite.git
synced 2025-07-07 15:55:00 +00:00
Add math-parser library (#2033)
* start of parser * ops forgot * reorder files and work on executer * start of parser * ops forgot * reorder files and work on executer * Cleanup and fix tests * Integrate into the editor * added unit checking at parse time * fix tests * fix issues * fix editor intergration * update pest grammer to support units * units should be working, need to set up tests to know * make unit type store exponants as i32 * remove scale, insted just multiply the literal by the scale * unit now contains empty unit,remove options * add more tests and implement almost all unary operators * add evaluation context and variables * function calling, api might be refined later * add constants, change function call to not be as built into the parser and add tests * add function definitions * remove meval * remove raw-rs from workspace * add support for numberless units * fix unit handleing logic, add some "unit" tests(haha) * make it so units cant do implcit mul with idents * add bench and better tests * fix editor api * remove old test * change hashmap context to use deref * change constants to use hashmap instad of function --------- Co-authored-by: hypercube <0hypercube@gmail.com> Co-authored-by: Keavon Chambers <keavon@keavon.com>
This commit is contained in:
parent
51ce51ea8c
commit
9fb494764c
14 changed files with 1260 additions and 94 deletions
75
Cargo.lock
generated
75
Cargo.lock
generated
|
@ -2534,7 +2534,7 @@ dependencies = [
|
|||
"graphite-editor",
|
||||
"js-sys",
|
||||
"log",
|
||||
"meval",
|
||||
"math-parser",
|
||||
"ron",
|
||||
"serde",
|
||||
"serde-wasm-bindgen",
|
||||
|
@ -3534,6 +3534,19 @@ version = "0.7.3"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94"
|
||||
|
||||
[[package]]
|
||||
name = "math-parser"
|
||||
version = "0.0.0"
|
||||
dependencies = [
|
||||
"criterion",
|
||||
"lazy_static",
|
||||
"log",
|
||||
"num-complex",
|
||||
"pest",
|
||||
"pest_derive",
|
||||
"thiserror",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "matrixmultiply"
|
||||
version = "0.3.9"
|
||||
|
@ -3592,15 +3605,6 @@ dependencies = [
|
|||
"paste",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "meval"
|
||||
version = "0.2.0"
|
||||
source = "git+https://github.com/Titaniumtown/meval-rs#6bf579fd402928745cf4f24e5c975bece3285179"
|
||||
dependencies = [
|
||||
"fnv",
|
||||
"nom",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mime"
|
||||
version = "0.3.17"
|
||||
|
@ -4341,6 +4345,51 @@ version = "2.3.1"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e"
|
||||
|
||||
[[package]]
|
||||
name = "pest"
|
||||
version = "2.7.13"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fdbef9d1d47087a895abd220ed25eb4ad973a5e26f6a4367b038c25e28dfc2d9"
|
||||
dependencies = [
|
||||
"memchr",
|
||||
"thiserror",
|
||||
"ucd-trie",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pest_derive"
|
||||
version = "2.7.13"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4d3a6e3394ec80feb3b6393c725571754c6188490265c61aaf260810d6b95aa0"
|
||||
dependencies = [
|
||||
"pest",
|
||||
"pest_generator",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pest_generator"
|
||||
version = "2.7.13"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "94429506bde1ca69d1b5601962c73f4172ab4726571a59ea95931218cb0e930e"
|
||||
dependencies = [
|
||||
"pest",
|
||||
"pest_meta",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.79",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pest_meta"
|
||||
version = "2.7.13"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ac8a071862e93690b6e34e9a5fb8e33ff3734473ac0245b27232222c4906a33f"
|
||||
dependencies = [
|
||||
"once_cell",
|
||||
"pest",
|
||||
"sha2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "petgraph"
|
||||
version = "0.6.5"
|
||||
|
@ -6758,6 +6807,12 @@ version = "1.17.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825"
|
||||
|
||||
[[package]]
|
||||
name = "ucd-trie"
|
||||
version = "0.1.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2896d95c02a80c6d6a5d6e953d479f5ddf2dfdb6a244441010e373ac0fb88971"
|
||||
|
||||
[[package]]
|
||||
name = "uds_windows"
|
||||
version = "1.1.0"
|
||||
|
|
|
@ -18,6 +18,7 @@ members = [
|
|||
"libraries/dyn-any",
|
||||
"libraries/path-bool",
|
||||
"libraries/bezier-rs",
|
||||
"libraries/math-parser",
|
||||
"website/other/bezier-rs-demos/wasm",
|
||||
]
|
||||
exclude = ["node-graph/gpu-compiler"]
|
||||
|
@ -31,6 +32,7 @@ graph-craft = { path = "node-graph/graph-craft", features = ["serde"] }
|
|||
wgpu-executor = { path = "node-graph/wgpu-executor" }
|
||||
bezier-rs = { path = "libraries/bezier-rs", features = ["dyn-any"] }
|
||||
path-bool = { path = "libraries/path-bool", default-features = false }
|
||||
math-parser = { path = "libraries/math-parser" }
|
||||
node-macro = { path = "node-graph/node-macro" }
|
||||
|
||||
# Workspace dependencies
|
||||
|
@ -77,7 +79,6 @@ glam = { version = "0.28", default-features = false, features = ["serde"] }
|
|||
base64 = "0.22"
|
||||
image = { version = "0.25", default-features = false, features = ["png"] }
|
||||
rustybuzz = "0.17"
|
||||
meval = "0.2"
|
||||
spirv = "0.3"
|
||||
fern = { version = "0.6", features = ["colored"] }
|
||||
num_enum = "0.7"
|
||||
|
@ -94,9 +95,6 @@ syn = { version = "2.0", default-features = false, features = [
|
|||
] }
|
||||
kurbo = { version = "0.11.0", features = ["serde"] }
|
||||
|
||||
[patch.crates-io]
|
||||
meval = { git = "https://github.com/Titaniumtown/meval-rs" }
|
||||
|
||||
[profile.dev]
|
||||
opt-level = 1
|
||||
|
||||
|
|
|
@ -37,7 +37,7 @@ js-sys = { workspace = true }
|
|||
wasm-bindgen-futures = { workspace = true }
|
||||
bezier-rs = { workspace = true }
|
||||
glam = { workspace = true }
|
||||
meval = { workspace = true }
|
||||
math-parser = { workspace = true }
|
||||
wgpu = { workspace = true, features = [
|
||||
"fragile-send-sync-non-atomic-wasm",
|
||||
] } # We don't have wgpu on multiple threads (yet) https://github.com/gfx-rs/wgpu/blob/trunk/CHANGELOG.md#wgpu-types-now-send-sync-on-wasm
|
||||
|
|
|
@ -910,72 +910,17 @@ impl EditorHandle {
|
|||
|
||||
#[wasm_bindgen(js_name = evaluateMathExpression)]
|
||||
pub fn evaluate_math_expression(expression: &str) -> Option<f64> {
|
||||
// TODO: Rewrite our own purpose-built math expression parser that supports unit conversions.
|
||||
|
||||
let mut context = meval::Context::new();
|
||||
context.var("tau", std::f64::consts::TAU);
|
||||
context.func("log", f64::log10);
|
||||
context.func("log10", f64::log10);
|
||||
context.func("log2", f64::log2);
|
||||
|
||||
// Insert asterisks where implicit multiplication is used in the expression string
|
||||
let expression = implicit_multiplication_preprocess(expression);
|
||||
|
||||
meval::eval_str_with_context(expression, &context).ok()
|
||||
}
|
||||
|
||||
// Modified from this public domain snippet: <https://gist.github.com/Titaniumtown/c181be5d06505e003d8c4d1e372684ff>
|
||||
// Discussion: <https://github.com/rekka/meval-rs/issues/28#issuecomment-1826381922>
|
||||
pub fn implicit_multiplication_preprocess(expression: &str) -> String {
|
||||
let function = expression.to_lowercase().replace("log10(", "log(").replace("log2(", "logtwo(").replace("pi", "π").replace("tau", "τ");
|
||||
let valid_variables: Vec<char> = "eπτ".chars().collect();
|
||||
let letters: Vec<char> = ('a'..='z').chain('A'..='Z').collect();
|
||||
let numbers: Vec<char> = ('0'..='9').collect();
|
||||
let function_chars: Vec<char> = function.chars().collect();
|
||||
let mut output_string: String = String::new();
|
||||
let mut prev_chars: Vec<char> = Vec::new();
|
||||
|
||||
for c in function_chars {
|
||||
let mut add_asterisk: bool = false;
|
||||
let prev_chars_len = prev_chars.len();
|
||||
|
||||
let prev_prev_char = if prev_chars_len >= 2 { *prev_chars.get(prev_chars_len - 2).unwrap() } else { ' ' };
|
||||
|
||||
let prev_char = if prev_chars_len >= 1 { *prev_chars.get(prev_chars_len - 1).unwrap() } else { ' ' };
|
||||
|
||||
let c_letters_var = letters.contains(&c) | valid_variables.contains(&c);
|
||||
let prev_letters_var = valid_variables.contains(&prev_char) | letters.contains(&prev_char);
|
||||
|
||||
if prev_char == ')' {
|
||||
if (c == '(') | numbers.contains(&c) | c_letters_var {
|
||||
add_asterisk = true;
|
||||
}
|
||||
} else if c == '(' {
|
||||
if (valid_variables.contains(&prev_char) | (')' == prev_char) | numbers.contains(&prev_char)) && !letters.contains(&prev_prev_char) {
|
||||
add_asterisk = true;
|
||||
}
|
||||
} else if numbers.contains(&prev_char) {
|
||||
if (c == '(') | c_letters_var {
|
||||
add_asterisk = true;
|
||||
}
|
||||
} else if letters.contains(&c) {
|
||||
if numbers.contains(&prev_char) | (valid_variables.contains(&prev_char) && valid_variables.contains(&c)) {
|
||||
add_asterisk = true;
|
||||
}
|
||||
} else if (numbers.contains(&c) | c_letters_var) && prev_letters_var {
|
||||
add_asterisk = true;
|
||||
}
|
||||
|
||||
if add_asterisk {
|
||||
output_string += "*";
|
||||
}
|
||||
|
||||
prev_chars.push(c);
|
||||
output_string += &c.to_string();
|
||||
}
|
||||
|
||||
// We have to convert the Greek symbols back to ASCII because meval doesn't support unicode symbols as context constants
|
||||
output_string.replace("logtwo(", "log2(").replace('π', "pi").replace('τ', "tau")
|
||||
let value = math_parser::evaluate(expression)
|
||||
.inspect_err(|err| error!("Math parser error on \"{expression}\": {err}"))
|
||||
.ok()?
|
||||
.0
|
||||
.inspect_err(|err| error!("Math evaluate error on \"{expression}\": {err} "))
|
||||
.ok()?;
|
||||
let Some(real) = value.as_real() else {
|
||||
error!("{value} was not a real; skipping.");
|
||||
return None;
|
||||
};
|
||||
Some(real)
|
||||
}
|
||||
|
||||
/// Helper function for calling JS's `requestAnimationFrame` with the given closure
|
||||
|
@ -1066,16 +1011,3 @@ fn auto_save_all_documents() {
|
|||
}
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn implicit_multiplication_preprocess_tests() {
|
||||
assert_eq!(implicit_multiplication_preprocess("2pi"), "2*pi");
|
||||
assert_eq!(implicit_multiplication_preprocess("sin(2pi)"), "sin(2*pi)");
|
||||
assert_eq!(implicit_multiplication_preprocess("2sin(pi)"), "2*sin(pi)");
|
||||
assert_eq!(implicit_multiplication_preprocess("2sin(3(4 + 5))"), "2*sin(3*(4 + 5))");
|
||||
assert_eq!(implicit_multiplication_preprocess("3abs(-4)"), "3*abs(-4)");
|
||||
assert_eq!(implicit_multiplication_preprocess("-1(4)"), "-1*(4)");
|
||||
assert_eq!(implicit_multiplication_preprocess("(-1)4"), "(-1)*4");
|
||||
assert_eq!(implicit_multiplication_preprocess("(((-1)))(4)"), "(((-1)))*(4)");
|
||||
assert_eq!(implicit_multiplication_preprocess("2sin(pi) + 2cos(tau)"), "2*sin(pi) + 2*cos(tau)");
|
||||
}
|
||||
|
|
23
libraries/math-parser/Cargo.toml
Normal file
23
libraries/math-parser/Cargo.toml
Normal file
|
@ -0,0 +1,23 @@
|
|||
[package]
|
||||
name = "math-parser"
|
||||
version = "0.0.0"
|
||||
rust-version = "1.79"
|
||||
edition = "2021"
|
||||
authors = ["Graphite Authors <contact@graphite.rs>"]
|
||||
description = "Parser for Graphite style mathematics expressions"
|
||||
license = "MIT OR Apache-2.0"
|
||||
|
||||
[dependencies]
|
||||
pest = "2.7"
|
||||
pest_derive = "2.7.11"
|
||||
thiserror = "1"
|
||||
lazy_static = "1.5"
|
||||
num-complex = "0.4"
|
||||
log = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
criterion = "0.5"
|
||||
|
||||
[[bench]]
|
||||
name = "bench"
|
||||
harness = false
|
50
libraries/math-parser/benches/bench.rs
Normal file
50
libraries/math-parser/benches/bench.rs
Normal file
|
@ -0,0 +1,50 @@
|
|||
use criterion::{black_box, criterion_group, criterion_main, Criterion};
|
||||
|
||||
use math_parser::ast;
|
||||
use math_parser::context::EvalContext;
|
||||
|
||||
macro_rules! generate_benchmarks {
|
||||
($( $input:expr ),* $(,)?) => {
|
||||
fn parsing_bench(c: &mut Criterion) {
|
||||
$(
|
||||
c.bench_function(concat!("parse ", $input), |b| {
|
||||
b.iter(|| {
|
||||
let _ = black_box(ast::Node::from_str($input)).unwrap();
|
||||
});
|
||||
});
|
||||
)*
|
||||
}
|
||||
|
||||
fn evaluation_bench(c: &mut Criterion) {
|
||||
$(
|
||||
let expr = ast::Node::from_str($input).unwrap().0;
|
||||
let context = EvalContext::default();
|
||||
|
||||
c.bench_function(concat!("eval ", $input), |b| {
|
||||
b.iter(|| {
|
||||
let _ = black_box(expr.eval(&context));
|
||||
});
|
||||
});
|
||||
)*
|
||||
}
|
||||
|
||||
criterion_group!(benches, parsing_bench, evaluation_bench);
|
||||
criterion_main!(benches);
|
||||
};
|
||||
}
|
||||
|
||||
generate_benchmarks! {
|
||||
"(3 * (4 + sqrt(25)) - cos(pi/3) * (2^3)) + 5 * e", // Mixed nested functions, constants, and operations
|
||||
"((5 + 2 * (3 - sqrt(49)))^2) / (1 + sqrt(16)) + tau / 2", // Complex nested expression with constants
|
||||
"log(100, 10) + (5 * sin(pi/4) + sqrt(81)) / (2 * phi)", // Logarithmic and trigonometric functions
|
||||
"(sqrt(144) * 2 + 5) / (3 * (4 - sin(pi / 6))) + e^2", // Combined square root, trigonometric, and exponential operations
|
||||
"cos(2 * pi) + tan(pi / 3) * log(32, 2) - sqrt(256)", // Multiple trigonometric and logarithmic functions
|
||||
"(10 * (3 + 2) - 8 / 2)^2 + 7 * (2^4) - sqrt(225) + phi", // Mixed arithmetic with constants
|
||||
"(5^2 + 3^3) * (sqrt(81) + sqrt(64)) - tau * log(1000, 10)", // Power and square root with constants
|
||||
"((8 * sqrt(49) - 2 * e) + log(256, 2) / (2 + cos(pi))) * 1.5", // Nested functions and constants
|
||||
"(tan(pi / 4) + 5) * (3 + sqrt(36)) / (log(1024, 2) - 4)", // Nested functions with trigonometry and logarithm
|
||||
"((3 * e + 2 * sqrt(100)) - cos(tau / 4)) * log(27, 3) + phi", // Mixed constant usage and functions
|
||||
"(sqrt(100) + 5 * sin(pi / 6) - 8 / log(64, 2)) + e^(1.5)", // Complex mix of square root, division, and exponentiation
|
||||
"((sin(pi/2) + cos(0)) * (e^2 - 2 * sqrt(16))) / (log(100, 10) + pi)", // Nested trigonometric, exponential, and logarithmic functions
|
||||
"(5 * (7 + sqrt(121)) - (log(243, 3) * phi)) + 3^5 / tau", //
|
||||
}
|
75
libraries/math-parser/src/ast.rs
Normal file
75
libraries/math-parser/src/ast.rs
Normal file
|
@ -0,0 +1,75 @@
|
|||
use crate::value::Complex;
|
||||
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub struct Unit {
|
||||
// Exponent of length unit (meters)
|
||||
pub length: i32,
|
||||
// Exponent of mass unit (kilograms)
|
||||
pub mass: i32,
|
||||
// Exponent of time unit (seconds)
|
||||
pub time: i32,
|
||||
}
|
||||
|
||||
impl Default for Unit {
|
||||
fn default() -> Self {
|
||||
Self::BASE_UNIT
|
||||
}
|
||||
}
|
||||
|
||||
impl Unit {
|
||||
pub const BASE_UNIT: Unit = Unit { length: 0, mass: 0, time: 0 };
|
||||
|
||||
pub const LENGTH: Unit = Unit { length: 1, mass: 0, time: 0 };
|
||||
pub const MASS: Unit = Unit { length: 0, mass: 1, time: 0 };
|
||||
pub const TIME: Unit = Unit { length: 0, mass: 0, time: 1 };
|
||||
|
||||
pub const VELOCITY: Unit = Unit { length: 1, mass: 0, time: -1 };
|
||||
pub const ACCELERATION: Unit = Unit { length: 1, mass: 0, time: -2 };
|
||||
|
||||
pub const FORCE: Unit = Unit { length: 1, mass: 1, time: -2 };
|
||||
|
||||
pub fn base_unit() -> Self {
|
||||
Self::BASE_UNIT
|
||||
}
|
||||
|
||||
pub fn is_base(&self) -> bool {
|
||||
*self == Self::BASE_UNIT
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub enum Literal {
|
||||
Float(f64),
|
||||
Complex(Complex),
|
||||
}
|
||||
|
||||
impl From<f64> for Literal {
|
||||
fn from(value: f64) -> Self {
|
||||
Self::Float(value)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
pub enum BinaryOp {
|
||||
Add,
|
||||
Sub,
|
||||
Mul,
|
||||
Div,
|
||||
Pow,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone, Copy)]
|
||||
pub enum UnaryOp {
|
||||
Neg,
|
||||
Sqrt,
|
||||
Fac,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub enum Node {
|
||||
Lit(Literal),
|
||||
Var(String),
|
||||
FnCall { name: String, expr: Vec<Node> },
|
||||
BinOp { lhs: Box<Node>, op: BinaryOp, rhs: Box<Node> },
|
||||
UnaryOp { expr: Box<Node>, op: UnaryOp },
|
||||
}
|
121
libraries/math-parser/src/constants.rs
Normal file
121
libraries/math-parser/src/constants.rs
Normal file
|
@ -0,0 +1,121 @@
|
|||
use std::{collections::HashMap, f64::consts::PI};
|
||||
|
||||
use lazy_static::lazy_static;
|
||||
use num_complex::{Complex, ComplexFloat};
|
||||
|
||||
use crate::value::{Number, Value};
|
||||
lazy_static! {
|
||||
pub static ref DEFAULT_FUNCTIONS: HashMap<&'static str, Box<dyn Fn(&[Value]) -> Option<Value> + Send + Sync>> = {
|
||||
let mut map: HashMap<&'static str, Box<dyn Fn(&[Value]) -> Option<Value> + Send + Sync>> = HashMap::new();
|
||||
|
||||
map.insert(
|
||||
"sin",
|
||||
Box::new(|values| match values {
|
||||
[Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.sin()))),
|
||||
[Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Complex(complex.sin()))),
|
||||
_ => None,
|
||||
}),
|
||||
);
|
||||
|
||||
map.insert(
|
||||
"cos",
|
||||
Box::new(|values| match values {
|
||||
[Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.cos()))),
|
||||
[Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Complex(complex.cos()))),
|
||||
_ => None,
|
||||
}),
|
||||
);
|
||||
|
||||
map.insert(
|
||||
"tan",
|
||||
Box::new(|values| match values {
|
||||
[Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.tan()))),
|
||||
[Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Complex(complex.tan()))),
|
||||
_ => None,
|
||||
}),
|
||||
);
|
||||
|
||||
map.insert(
|
||||
"csc",
|
||||
Box::new(|values| match values {
|
||||
[Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.sin().recip()))),
|
||||
[Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Complex(complex.sin().recip()))),
|
||||
_ => None,
|
||||
}),
|
||||
);
|
||||
|
||||
map.insert(
|
||||
"sec",
|
||||
Box::new(|values| match values {
|
||||
[Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.cos().recip()))),
|
||||
[Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Complex(complex.cos().recip()))),
|
||||
_ => None,
|
||||
}),
|
||||
);
|
||||
|
||||
map.insert(
|
||||
"cot",
|
||||
Box::new(|values| match values {
|
||||
[Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.tan().recip()))),
|
||||
[Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Complex(complex.tan().recip()))),
|
||||
_ => None,
|
||||
}),
|
||||
);
|
||||
|
||||
map.insert(
|
||||
"invsin",
|
||||
Box::new(|values| match values {
|
||||
[Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.asin()))),
|
||||
[Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Complex(complex.asin()))),
|
||||
_ => None,
|
||||
}),
|
||||
);
|
||||
|
||||
map.insert(
|
||||
"invcos",
|
||||
Box::new(|values| match values {
|
||||
[Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.acos()))),
|
||||
[Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Complex(complex.acos()))),
|
||||
_ => None,
|
||||
}),
|
||||
);
|
||||
|
||||
map.insert(
|
||||
"invtan",
|
||||
Box::new(|values| match values {
|
||||
[Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.atan()))),
|
||||
[Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Complex(complex.atan()))),
|
||||
_ => None,
|
||||
}),
|
||||
);
|
||||
|
||||
map.insert(
|
||||
"invcsc",
|
||||
Box::new(|values| match values {
|
||||
[Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.recip().asin()))),
|
||||
[Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Complex(complex.recip().asin()))),
|
||||
_ => None,
|
||||
}),
|
||||
);
|
||||
|
||||
map.insert(
|
||||
"invsec",
|
||||
Box::new(|values| match values {
|
||||
[Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real(real.recip().acos()))),
|
||||
[Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Complex(complex.recip().acos()))),
|
||||
_ => None,
|
||||
}),
|
||||
);
|
||||
|
||||
map.insert(
|
||||
"invcot",
|
||||
Box::new(|values| match values {
|
||||
[Value::Number(Number::Real(real))] => Some(Value::Number(Number::Real((PI / 2.0 - real).atan()))),
|
||||
[Value::Number(Number::Complex(complex))] => Some(Value::Number(Number::Complex((Complex::new(PI / 2.0, 0.0) - complex).atan()))),
|
||||
_ => None,
|
||||
}),
|
||||
);
|
||||
|
||||
map
|
||||
};
|
||||
}
|
83
libraries/math-parser/src/context.rs
Normal file
83
libraries/math-parser/src/context.rs
Normal file
|
@ -0,0 +1,83 @@
|
|||
use std::{
|
||||
collections::HashMap,
|
||||
ops::{Deref, DerefMut},
|
||||
};
|
||||
|
||||
use crate::value::Value;
|
||||
|
||||
//TODO: editor integration, implement these traits for whatever is needed, maybe merge them if needed
|
||||
pub trait ValueProvider {
|
||||
fn get_value(&self, name: &str) -> Option<Value>;
|
||||
}
|
||||
|
||||
pub trait FunctionProvider {
|
||||
fn run_function(&self, name: &str, args: &[Value]) -> Option<Value>;
|
||||
}
|
||||
|
||||
pub struct ValueMap(HashMap<String, Value>);
|
||||
|
||||
pub struct NothingMap;
|
||||
|
||||
impl ValueProvider for &ValueMap {
|
||||
fn get_value(&self, name: &str) -> Option<Value> {
|
||||
self.0.get(name).cloned()
|
||||
}
|
||||
}
|
||||
|
||||
impl ValueProvider for NothingMap {
|
||||
fn get_value(&self, _: &str) -> Option<Value> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
impl ValueProvider for ValueMap {
|
||||
fn get_value(&self, name: &str) -> Option<Value> {
|
||||
self.0.get(name).cloned()
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for ValueMap {
|
||||
type Target = HashMap<String, Value>;
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
impl DerefMut for ValueMap {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl FunctionProvider for NothingMap {
|
||||
fn run_function(&self, _: &str, _: &[Value]) -> Option<Value> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub struct EvalContext<V: ValueProvider, F: FunctionProvider> {
|
||||
values: V,
|
||||
functions: F,
|
||||
}
|
||||
|
||||
impl Default for EvalContext<NothingMap, NothingMap> {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
values: NothingMap,
|
||||
functions: NothingMap,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<V: ValueProvider, F: FunctionProvider> EvalContext<V, F> {
|
||||
pub fn new(values: V, functions: F) -> Self {
|
||||
Self { values, functions }
|
||||
}
|
||||
|
||||
pub fn get_value(&self, name: &str) -> Option<Value> {
|
||||
self.values.get_value(name)
|
||||
}
|
||||
|
||||
pub fn run_function(&self, name: &str, args: &[Value]) -> Option<Value> {
|
||||
self.functions.run_function(name, args)
|
||||
}
|
||||
}
|
105
libraries/math-parser/src/executer.rs
Normal file
105
libraries/math-parser/src/executer.rs
Normal file
|
@ -0,0 +1,105 @@
|
|||
use thiserror::Error;
|
||||
|
||||
use crate::{
|
||||
ast::{Literal, Node},
|
||||
constants::DEFAULT_FUNCTIONS,
|
||||
context::{EvalContext, FunctionProvider, ValueProvider},
|
||||
value::{Number, Value},
|
||||
};
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum EvalError {
|
||||
#[error("Missing value: {0}")]
|
||||
MissingValue(String),
|
||||
|
||||
#[error("Missing function: {0}")]
|
||||
MissingFunction(String),
|
||||
#[error("Wrong type for function call")]
|
||||
TypeError,
|
||||
}
|
||||
|
||||
impl Node {
|
||||
pub fn eval<V: ValueProvider, F: FunctionProvider>(&self, context: &EvalContext<V, F>) -> Result<Value, EvalError> {
|
||||
match self {
|
||||
Node::Lit(lit) => match lit {
|
||||
Literal::Float(num) => Ok(Value::from_f64(*num)),
|
||||
Literal::Complex(num) => Ok(Value::Number(Number::Complex(*num))),
|
||||
},
|
||||
|
||||
Node::BinOp { lhs, op, rhs } => match (lhs.eval(context)?, rhs.eval(context)?) {
|
||||
(Value::Number(lhs), Value::Number(rhs)) => Ok(Value::Number(lhs.binary_op(*op, rhs))),
|
||||
},
|
||||
Node::UnaryOp { expr, op } => match expr.eval(context)? {
|
||||
Value::Number(num) => Ok(Value::Number(num.unary_op(*op))),
|
||||
},
|
||||
Node::Var(name) => context.get_value(name).ok_or_else(|| EvalError::MissingValue(name.clone())),
|
||||
Node::FnCall { name, expr } => {
|
||||
let values = expr.iter().map(|expr| expr.eval(context)).collect::<Result<Vec<Value>, EvalError>>()?;
|
||||
if let Some(function) = DEFAULT_FUNCTIONS.get(&name.as_str()) {
|
||||
function(&values).ok_or(EvalError::TypeError)
|
||||
} else if let Some(val) = context.run_function(name, &values) {
|
||||
Ok(val)
|
||||
} else {
|
||||
context.get_value(name).ok_or_else(|| EvalError::MissingFunction(name.to_string()))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::{
|
||||
ast::{BinaryOp, Literal, Node, UnaryOp},
|
||||
context::{EvalContext, ValueMap},
|
||||
value::Value,
|
||||
};
|
||||
|
||||
macro_rules! eval_tests {
|
||||
($($name:ident: $expected:expr => $expr:expr),* $(,)?) => {
|
||||
$(
|
||||
#[test]
|
||||
fn $name() {
|
||||
let result = $expr.eval(&EvalContext::default()).unwrap();
|
||||
assert_eq!(result, $expected);
|
||||
}
|
||||
)*
|
||||
};
|
||||
}
|
||||
|
||||
eval_tests! {
|
||||
test_addition: Value::from_f64(7.0) => Node::BinOp {
|
||||
lhs: Box::new(Node::Lit(Literal::Float(3.0))),
|
||||
op: BinaryOp::Add,
|
||||
rhs: Box::new(Node::Lit(Literal::Float(4.0))),
|
||||
},
|
||||
test_subtraction: Value::from_f64(1.0) => Node::BinOp {
|
||||
lhs: Box::new(Node::Lit(Literal::Float(5.0))),
|
||||
op: BinaryOp::Sub,
|
||||
rhs: Box::new(Node::Lit(Literal::Float(4.0))),
|
||||
},
|
||||
test_multiplication: Value::from_f64(12.0) => Node::BinOp {
|
||||
lhs: Box::new(Node::Lit(Literal::Float(3.0))),
|
||||
op: BinaryOp::Mul,
|
||||
rhs: Box::new(Node::Lit(Literal::Float(4.0))),
|
||||
},
|
||||
test_division: Value::from_f64(2.5) => Node::BinOp {
|
||||
lhs: Box::new(Node::Lit(Literal::Float(5.0))),
|
||||
op: BinaryOp::Div,
|
||||
rhs: Box::new(Node::Lit(Literal::Float(2.0))),
|
||||
},
|
||||
test_negation: Value::from_f64(-3.0) => Node::UnaryOp {
|
||||
expr: Box::new(Node::Lit(Literal::Float(3.0))),
|
||||
op: UnaryOp::Neg,
|
||||
},
|
||||
test_sqrt: Value::from_f64(2.0) => Node::UnaryOp {
|
||||
expr: Box::new(Node::Lit(Literal::Float(4.0))),
|
||||
op: UnaryOp::Sqrt,
|
||||
},
|
||||
test_power: Value::from_f64(8.0) => Node::BinOp {
|
||||
lhs: Box::new(Node::Lit(Literal::Float(2.0))),
|
||||
op: BinaryOp::Pow,
|
||||
rhs: Box::new(Node::Lit(Literal::Float(3.0))),
|
||||
},
|
||||
}
|
||||
}
|
60
libraries/math-parser/src/grammer.pest
Normal file
60
libraries/math-parser/src/grammer.pest
Normal file
|
@ -0,0 +1,60 @@
|
|||
WHITESPACE = _{ " " | "\t" }
|
||||
|
||||
// TODO: Proper indentation and formatting
|
||||
program = _{ SOI ~ expr ~ EOI }
|
||||
|
||||
expr = { atom ~ (infix ~ atom)* }
|
||||
atom = _{ prefix? ~ primary ~ postfix? }
|
||||
infix = _{ add | sub | mul | div | pow | paren }
|
||||
add = { "+" } // Addition
|
||||
sub = { "-" } // Subtraction
|
||||
mul = { "*" } // Multiplication
|
||||
div = { "/" } // Division
|
||||
mod = { "%" } // Modulo
|
||||
pow = { "^" } // Exponentiation
|
||||
paren = { "" } // Implicit multiplication operator
|
||||
|
||||
prefix = _{ neg | sqrt }
|
||||
neg = { "-" } // Negation
|
||||
sqrt = { "sqrt" }
|
||||
|
||||
postfix = _{ fac }
|
||||
fac = { "!" } // Factorial
|
||||
|
||||
primary = _{ ("(" ~ expr ~ ")") | lit | constant | fn_call | ident }
|
||||
fn_call = { ident ~ "(" ~ expr ~ ("," ~ expr)* ~ ")" }
|
||||
ident = @{ (ASCII_ALPHA | "_") ~ (ASCII_ALPHANUMERIC | "_")* }
|
||||
lit = { unit | ((float | int) ~ unit?) }
|
||||
|
||||
float = @{ int ~ "." ~ int? ~ exp? | int ~ exp }
|
||||
exp = _{ ^"e" ~ ("+" | "-")? ~ int }
|
||||
int = @{ ASCII_DIGIT+ }
|
||||
|
||||
unit = ${ (scale ~ base_unit) | base_unit ~ !ident}
|
||||
base_unit = _{ meter | second | gram }
|
||||
meter = { "m" }
|
||||
second = { "s" }
|
||||
gram = { "g" }
|
||||
|
||||
scale = _{ nano | micro | milli | centi | deci | deca | hecto | kilo | mega | giga | tera }
|
||||
nano = { "n" }
|
||||
micro = { "µ" | "u" }
|
||||
milli = { "m" }
|
||||
centi = { "c" }
|
||||
deci = { "d" }
|
||||
deca = { "da" }
|
||||
hecto = { "h" }
|
||||
kilo = { "k" }
|
||||
mega = { "M" }
|
||||
giga = { "G" }
|
||||
tera = { "T" }
|
||||
|
||||
// Constants
|
||||
constant = { infinity | imaginary_unit | pi | tau | euler_number | golden_ratio | gravity_acceleration }
|
||||
infinity = { "inf" | "INF" | "infinity" | "INFINITY" | "∞" }
|
||||
imaginary_unit = { "i" | "I" }
|
||||
pi = { "pi" | "PI" | "π" }
|
||||
tau = { "tau" | "TAU" | "τ" }
|
||||
euler_number = { "e" }
|
||||
golden_ratio = { "phi" | "PHI" | "φ" }
|
||||
gravity_acceleration = { "G" }
|
151
libraries/math-parser/src/lib.rs
Normal file
151
libraries/math-parser/src/lib.rs
Normal file
|
@ -0,0 +1,151 @@
|
|||
#![allow(unused)]
|
||||
|
||||
pub mod ast;
|
||||
mod constants;
|
||||
pub mod context;
|
||||
pub mod executer;
|
||||
pub mod parser;
|
||||
pub mod value;
|
||||
|
||||
use ast::Unit;
|
||||
use context::{EvalContext, ValueMap};
|
||||
use executer::EvalError;
|
||||
use parser::ParseError;
|
||||
use value::Value;
|
||||
|
||||
pub fn evaluate(expression: &str) -> Result<(Result<Value, EvalError>, Unit), ParseError> {
|
||||
let expr = ast::Node::from_str(expression);
|
||||
let context = EvalContext::default();
|
||||
expr.map(|(node, unit)| (node.eval(&context), unit))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use value::Number;
|
||||
|
||||
use ast::Unit;
|
||||
|
||||
use super::*;
|
||||
|
||||
const EPSILON: f64 = 1e10_f64;
|
||||
|
||||
macro_rules! test_end_to_end{
|
||||
($($name:ident: $input:expr => ($expected_value:expr, $expected_unit:expr)),* $(,)?) => {
|
||||
$(
|
||||
#[test]
|
||||
fn $name() {
|
||||
let expected_value = $expected_value;
|
||||
let expected_unit = $expected_unit;
|
||||
|
||||
let expr = ast::Node::from_str($input);
|
||||
let context = EvalContext::default();
|
||||
|
||||
let (actual_value, actual_unit) = expr.map(|(node, unit)| (node.eval(&context), unit)).unwrap();
|
||||
let actual_value = actual_value.unwrap();
|
||||
|
||||
|
||||
assert!(actual_unit == expected_unit, "Expected unit {:?} but found unit {:?}", expected_unit, actual_unit);
|
||||
|
||||
let expected_value = expected_value.into();
|
||||
|
||||
match (actual_value, expected_value) {
|
||||
(Value::Number(Number::Complex(actual_c)), Value::Number(Number::Complex(expected_c))) => {
|
||||
assert!(
|
||||
(actual_c.re.is_infinite() && expected_c.re.is_infinite()) || (actual_c.re - expected_c.re).abs() < EPSILON,
|
||||
"Expected real part {}, but got {}",
|
||||
expected_c.re,
|
||||
actual_c.re
|
||||
);
|
||||
assert!(
|
||||
(actual_c.im.is_infinite() && expected_c.im.is_infinite()) || (actual_c.im - expected_c.im).abs() < EPSILON,
|
||||
"Expected imaginary part {}, but got {}",
|
||||
expected_c.im,
|
||||
actual_c.im
|
||||
);
|
||||
}
|
||||
(Value::Number(Number::Real(actual_f)), Value::Number(Number::Real(expected_f))) => {
|
||||
if actual_f.is_infinite() || expected_f.is_infinite() {
|
||||
assert!(
|
||||
actual_f.is_infinite() && expected_f.is_infinite() && actual_f == expected_f,
|
||||
"Expected infinite value {}, but got {}",
|
||||
expected_f,
|
||||
actual_f
|
||||
);
|
||||
} else if actual_f.is_nan() || expected_f.is_nan() {
|
||||
assert!(actual_f.is_nan() && expected_f.is_nan(), "Expected NaN, but got {}", actual_f);
|
||||
} else {
|
||||
assert!((actual_f - expected_f).abs() < EPSILON, "Expected {}, but got {}", expected_f, actual_f);
|
||||
}
|
||||
}
|
||||
// Handle mismatched types
|
||||
_ => panic!("Mismatched types: expected {:?}, got {:?}", expected_value, actual_value),
|
||||
}
|
||||
|
||||
}
|
||||
)*
|
||||
};
|
||||
}
|
||||
|
||||
test_end_to_end! {
|
||||
// Basic arithmetic and units
|
||||
infix_addition: "5 + 5" => (10., Unit::BASE_UNIT),
|
||||
infix_subtraction_units: "5m - 3m" => (2., Unit::LENGTH),
|
||||
infix_multiplication_units: "4s * 4s" => (16., Unit { length: 0, mass: 0, time: 2 }),
|
||||
infix_division_units: "8m/2s" => (4., Unit::VELOCITY),
|
||||
|
||||
// Order of operations
|
||||
order_of_operations_negative_prefix: "-10 + 5" => (-5., Unit::BASE_UNIT),
|
||||
order_of_operations_add_multiply: "5+1*1+5" => (11., Unit::BASE_UNIT),
|
||||
order_of_operations_add_negative_multiply: "5+(-1)*1+5" => (9., Unit::BASE_UNIT),
|
||||
order_of_operations_sqrt: "sqrt25 + 11" => (16., Unit::BASE_UNIT),
|
||||
order_of_operations_sqrt_expression: "sqrt(25+11)" => (6., Unit::BASE_UNIT),
|
||||
|
||||
// Parentheses and nested expressions
|
||||
parentheses_nested_multiply: "(5 + 3) * (2 + 6)" => (64., Unit::BASE_UNIT),
|
||||
parentheses_mixed_operations: "2 * (3 + 5 * (2 + 1))" => (36., Unit::BASE_UNIT),
|
||||
parentheses_divide_add_multiply: "10 / (2 + 3) + (7 * 2)" => (16., Unit::BASE_UNIT),
|
||||
|
||||
// Square root and nested square root
|
||||
sqrt_chain_operations: "sqrt(16) + sqrt(9) * sqrt(4)" => (10., Unit::BASE_UNIT),
|
||||
sqrt_nested: "sqrt(sqrt(81))" => (3., Unit::BASE_UNIT),
|
||||
sqrt_divide_expression: "sqrt((25 + 11) / 9)" => (2., Unit::BASE_UNIT),
|
||||
|
||||
// Mixed square root and units
|
||||
sqrt_multiply_units: "sqrt(16) * 2g + 5g" => (13., Unit::MASS),
|
||||
sqrt_add_multiply: "sqrt(49) - 1 + 2 * 3" => (12., Unit::BASE_UNIT),
|
||||
sqrt_addition_multiply: "(sqrt(36) + 2) * 2" => (16., Unit::BASE_UNIT),
|
||||
|
||||
// Exponentiation
|
||||
exponent_single: "2^3" => (8., Unit::BASE_UNIT),
|
||||
exponent_mixed_operations: "2^3 + 4^2" => (24., Unit::BASE_UNIT),
|
||||
exponent_nested: "2^(3+1)" => (16., Unit::BASE_UNIT),
|
||||
|
||||
// Operations with negative values
|
||||
negative_units_add_multiply: "-5s + (-3 * 2)s" => (-11., Unit::TIME),
|
||||
negative_nested_parentheses: "-(5 + 3 * (2 - 1))" => (-8., Unit::BASE_UNIT),
|
||||
negative_sqrt_addition: "-(sqrt(16) + sqrt(9))" => (-7., Unit::BASE_UNIT),
|
||||
multiply_sqrt_subtract: "5 * 2 + sqrt(16) / 2 - 3" => (9., Unit::BASE_UNIT),
|
||||
add_multiply_subtract_sqrt: "4 + 3 * (2 + 1) - sqrt(25)" => (8., Unit::BASE_UNIT),
|
||||
add_sqrt_subtract_nested_multiply: "10 + sqrt(64) - (5 * (2 + 1))" => (3., Unit::BASE_UNIT),
|
||||
|
||||
// Mathematical constants
|
||||
constant_pi: "pi" => (std::f64::consts::PI, Unit::BASE_UNIT),
|
||||
constant_e: "e" => (std::f64::consts::E, Unit::BASE_UNIT),
|
||||
constant_phi: "phi" => (1.61803398875, Unit::BASE_UNIT),
|
||||
constant_tau: "tau" => (2.0 * std::f64::consts::PI, Unit::BASE_UNIT),
|
||||
constant_infinity: "inf" => (f64::INFINITY, Unit::BASE_UNIT),
|
||||
constant_infinity_symbol: "∞" => (f64::INFINITY, Unit::BASE_UNIT),
|
||||
multiply_pi: "2 * pi" => (2.0 * std::f64::consts::PI, Unit::BASE_UNIT),
|
||||
add_e_constant: "e + 1" => (std::f64::consts::E + 1.0, Unit::BASE_UNIT),
|
||||
multiply_phi_constant: "phi * 2" => (1.61803398875 * 2.0, Unit::BASE_UNIT),
|
||||
exponent_tau: "2^tau" => (2f64.powf(2.0 * std::f64::consts::PI), Unit::BASE_UNIT),
|
||||
infinity_subtract_large_number: "inf - 1000" => (f64::INFINITY, Unit::BASE_UNIT),
|
||||
|
||||
// Trigonometric functions
|
||||
trig_sin_pi: "sin(pi)" => (0.0, Unit::BASE_UNIT),
|
||||
trig_cos_zero: "cos(0)" => (1.0, Unit::BASE_UNIT),
|
||||
trig_tan_pi_div_four: "tan(pi/4)" => (1.0, Unit::BASE_UNIT),
|
||||
trig_sin_tau: "sin(tau)" => (0.0, Unit::BASE_UNIT),
|
||||
trig_cos_tau_div_two: "cos(tau/2)" => (-1.0, Unit::BASE_UNIT),
|
||||
}
|
||||
}
|
385
libraries/math-parser/src/parser.rs
Normal file
385
libraries/math-parser/src/parser.rs
Normal file
|
@ -0,0 +1,385 @@
|
|||
use std::num::{ParseFloatError, ParseIntError};
|
||||
|
||||
use lazy_static::lazy_static;
|
||||
use num_complex::ComplexFloat;
|
||||
use pest::{
|
||||
iterators::{Pair, Pairs},
|
||||
pratt_parser::{Assoc, Op, PrattParser},
|
||||
Parser,
|
||||
};
|
||||
use pest_derive::Parser;
|
||||
use thiserror::Error;
|
||||
|
||||
use crate::{
|
||||
ast::{BinaryOp, Literal, Node, UnaryOp, Unit},
|
||||
context::EvalContext,
|
||||
value::{Complex, Number, Value},
|
||||
};
|
||||
|
||||
#[derive(Parser)]
|
||||
#[grammar = "./grammer.pest"] // Point to the grammar file
|
||||
struct ExprParser;
|
||||
|
||||
lazy_static! {
|
||||
static ref PRATT_PARSER: PrattParser<Rule> = {
|
||||
PrattParser::new()
|
||||
.op(Op::infix(Rule::add, Assoc::Left) | Op::infix(Rule::sub, Assoc::Left))
|
||||
.op(Op::infix(Rule::mul, Assoc::Left) | Op::infix(Rule::div, Assoc::Left) | Op::infix(Rule::paren, Assoc::Left))
|
||||
.op(Op::infix(Rule::pow, Assoc::Right))
|
||||
.op(Op::postfix(Rule::fac) | Op::postfix(Rule::EOI))
|
||||
.op(Op::prefix(Rule::sqrt))
|
||||
.op(Op::prefix(Rule::neg))
|
||||
};
|
||||
}
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum TypeError {
|
||||
#[error("Invalid BinOp: {0:?} {1:?} {2:?}")]
|
||||
InvalidBinaryOp(Unit, BinaryOp, Unit),
|
||||
|
||||
#[error("Invalid UnaryOp: {0:?}")]
|
||||
InvalidUnaryOp(Unit, UnaryOp),
|
||||
}
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum ParseError {
|
||||
#[error("ParseIntError: {0}")]
|
||||
ParseInt(#[from] ParseIntError),
|
||||
#[error("ParseFloatError: {0}")]
|
||||
ParseFloat(#[from] ParseFloatError),
|
||||
|
||||
#[error("TypeError: {0}")]
|
||||
Type(#[from] TypeError),
|
||||
|
||||
#[error("PestError: {0}")]
|
||||
Pest(#[from] Box<pest::error::Error<Rule>>),
|
||||
}
|
||||
|
||||
impl Node {
|
||||
pub fn from_str(s: &str) -> Result<(Node, Unit), ParseError> {
|
||||
let pairs = ExprParser::parse(Rule::program, s).map_err(Box::new)?;
|
||||
let (node, metadata) = parse_expr(pairs)?;
|
||||
Ok((node, metadata.unit))
|
||||
}
|
||||
}
|
||||
|
||||
struct NodeMetadata {
|
||||
pub unit: Unit,
|
||||
}
|
||||
|
||||
impl NodeMetadata {
|
||||
pub fn new(unit: Unit) -> Self {
|
||||
Self { unit }
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_unit(pairs: Pairs<Rule>) -> Result<(Unit, f64), ParseError> {
|
||||
let mut scale = 1.0;
|
||||
let mut length = 0;
|
||||
let mut mass = 0;
|
||||
let mut time = 0;
|
||||
|
||||
for pair in pairs {
|
||||
println!("found rule: {:?}", pair.as_rule());
|
||||
match pair.as_rule() {
|
||||
Rule::nano => scale *= 1e-9,
|
||||
Rule::micro => scale *= 1e-6,
|
||||
Rule::milli => scale *= 1e-3,
|
||||
Rule::centi => scale *= 1e-2,
|
||||
Rule::deci => scale *= 1e-1,
|
||||
Rule::deca => scale *= 1e1,
|
||||
Rule::hecto => scale *= 1e2,
|
||||
Rule::kilo => scale *= 1e3,
|
||||
Rule::mega => scale *= 1e6,
|
||||
Rule::giga => scale *= 1e9,
|
||||
Rule::tera => scale *= 1e12,
|
||||
|
||||
Rule::meter => length = 1,
|
||||
Rule::gram => mass = 1,
|
||||
Rule::second => time = 1,
|
||||
|
||||
_ => unreachable!(), // All possible rules should be covered
|
||||
}
|
||||
}
|
||||
|
||||
Ok((Unit { length, mass, time }, scale))
|
||||
}
|
||||
|
||||
fn parse_const(pair: Pair<Rule>) -> Literal {
|
||||
match pair.as_rule() {
|
||||
Rule::infinity => Literal::Float(f64::INFINITY),
|
||||
Rule::imaginary_unit => Literal::Complex(Complex::new(0.0, 1.0)),
|
||||
Rule::pi => Literal::Float(std::f64::consts::PI),
|
||||
Rule::tau => Literal::Float(2.0 * std::f64::consts::PI),
|
||||
Rule::euler_number => Literal::Float(std::f64::consts::E),
|
||||
Rule::golden_ratio => Literal::Float(1.61803398875),
|
||||
_ => unreachable!("Unexpected constant: {:?}", pair),
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_lit(mut pairs: Pairs<Rule>) -> Result<(Literal, Unit), ParseError> {
|
||||
let literal = match pairs.next() {
|
||||
Some(lit) => match lit.as_rule() {
|
||||
Rule::int => {
|
||||
let value = lit.as_str().parse::<i32>()? as f64;
|
||||
Literal::Float(value)
|
||||
}
|
||||
Rule::float => {
|
||||
let value = lit.as_str().parse::<f64>()?;
|
||||
Literal::Float(value)
|
||||
}
|
||||
Rule::unit => {
|
||||
let (unit, scale) = parse_unit(lit.into_inner())?;
|
||||
return Ok((Literal::Float(scale), unit));
|
||||
}
|
||||
rule => unreachable!("unexpected rule: {:?}", rule),
|
||||
},
|
||||
None => unreachable!("expected rule"), // No literal found
|
||||
};
|
||||
|
||||
if let Some(unit_pair) = pairs.next() {
|
||||
let unit_pairs = unit_pair.into_inner(); // Get the inner pairs for the unit
|
||||
let (unit, scale) = parse_unit(unit_pairs)?;
|
||||
|
||||
println!("found unit: {:?}", unit);
|
||||
|
||||
Ok((
|
||||
match literal {
|
||||
Literal::Float(num) => Literal::Float(num * scale),
|
||||
Literal::Complex(num) => Literal::Complex(num * scale),
|
||||
},
|
||||
unit,
|
||||
))
|
||||
} else {
|
||||
Ok((literal, Unit::BASE_UNIT))
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_expr(pairs: Pairs<Rule>) -> Result<(Node, NodeMetadata), ParseError> {
|
||||
PRATT_PARSER
|
||||
.map_primary(|primary| {
|
||||
Ok(match primary.as_rule() {
|
||||
Rule::lit => {
|
||||
let (lit, unit) = parse_lit(primary.into_inner())?;
|
||||
|
||||
(Node::Lit(lit), NodeMetadata { unit })
|
||||
}
|
||||
Rule::fn_call => {
|
||||
let mut pairs = primary.into_inner();
|
||||
let name = pairs.next().expect("fn_call always has 2 children").as_str().to_string();
|
||||
|
||||
(
|
||||
Node::FnCall {
|
||||
name,
|
||||
expr: pairs.map(|p| parse_expr(p.into_inner()).map(|expr| expr.0)).collect::<Result<Vec<Node>, ParseError>>()?,
|
||||
},
|
||||
NodeMetadata::new(Unit::BASE_UNIT),
|
||||
)
|
||||
}
|
||||
Rule::constant => {
|
||||
let lit = parse_const(primary.into_inner().next().expect("constant should have atleast 1 child"));
|
||||
|
||||
(Node::Lit(lit), NodeMetadata::new(Unit::BASE_UNIT))
|
||||
}
|
||||
Rule::ident => {
|
||||
let name = primary.as_str().to_string();
|
||||
|
||||
(Node::Var(name), NodeMetadata::new(Unit::BASE_UNIT))
|
||||
}
|
||||
Rule::expr => parse_expr(primary.into_inner())?,
|
||||
Rule::float => {
|
||||
let value = primary.as_str().parse::<f64>()?;
|
||||
(Node::Lit(Literal::Float(value)), NodeMetadata::new(Unit::BASE_UNIT))
|
||||
}
|
||||
rule => unreachable!("unexpected rule: {:?}", rule),
|
||||
})
|
||||
})
|
||||
.map_prefix(|op, rhs| {
|
||||
let (rhs, rhs_metadata) = rhs?;
|
||||
let op = match op.as_rule() {
|
||||
Rule::neg => UnaryOp::Neg,
|
||||
Rule::sqrt => UnaryOp::Sqrt,
|
||||
|
||||
rule => unreachable!("unexpected rule: {:?}", rule),
|
||||
};
|
||||
|
||||
let node = Node::UnaryOp { expr: Box::new(rhs), op };
|
||||
let unit = rhs_metadata.unit;
|
||||
|
||||
let unit = if !unit.is_base() {
|
||||
match op {
|
||||
UnaryOp::Sqrt if unit.length % 2 == 0 && unit.mass % 2 == 0 && unit.time % 2 == 0 => Unit {
|
||||
length: unit.length / 2,
|
||||
mass: unit.mass / 2,
|
||||
time: unit.time / 2,
|
||||
},
|
||||
UnaryOp::Neg => unit,
|
||||
op => return Err(ParseError::Type(TypeError::InvalidUnaryOp(unit, op))),
|
||||
}
|
||||
} else {
|
||||
Unit::BASE_UNIT
|
||||
};
|
||||
|
||||
Ok((node, NodeMetadata::new(unit)))
|
||||
})
|
||||
.map_postfix(|lhs, op| {
|
||||
let (lhs_node, lhs_metadata) = lhs?;
|
||||
|
||||
let op = match op.as_rule() {
|
||||
Rule::EOI => return Ok((lhs_node, lhs_metadata)),
|
||||
Rule::fac => UnaryOp::Fac,
|
||||
rule => unreachable!("unexpected rule: {:?}", rule),
|
||||
};
|
||||
|
||||
if !lhs_metadata.unit.is_base() {
|
||||
return Err(ParseError::Type(TypeError::InvalidUnaryOp(lhs_metadata.unit, op)));
|
||||
}
|
||||
|
||||
Ok((Node::UnaryOp { expr: Box::new(lhs_node), op }, lhs_metadata))
|
||||
})
|
||||
.map_infix(|lhs, op, rhs| {
|
||||
let (lhs, lhs_metadata) = lhs?;
|
||||
let (rhs, rhs_metadata) = rhs?;
|
||||
|
||||
let op = match op.as_rule() {
|
||||
Rule::add => BinaryOp::Add,
|
||||
Rule::sub => BinaryOp::Sub,
|
||||
Rule::mul => BinaryOp::Mul,
|
||||
Rule::div => BinaryOp::Div,
|
||||
Rule::pow => BinaryOp::Pow,
|
||||
Rule::paren => BinaryOp::Mul,
|
||||
rule => unreachable!("unexpected rule: {:?}", rule),
|
||||
};
|
||||
|
||||
let (lhs_unit, rhs_unit) = (lhs_metadata.unit, rhs_metadata.unit);
|
||||
|
||||
let unit = match (!lhs_unit.is_base(), !rhs_unit.is_base()) {
|
||||
(true, true) => match op {
|
||||
BinaryOp::Mul => Unit {
|
||||
length: lhs_unit.length + rhs_unit.length,
|
||||
mass: lhs_unit.mass + rhs_unit.mass,
|
||||
time: lhs_unit.time + rhs_unit.time,
|
||||
},
|
||||
BinaryOp::Div => Unit {
|
||||
length: lhs_unit.length - rhs_unit.length,
|
||||
mass: lhs_unit.mass - rhs_unit.mass,
|
||||
time: lhs_unit.time - rhs_unit.time,
|
||||
},
|
||||
BinaryOp::Add | BinaryOp::Sub => {
|
||||
if lhs_unit == rhs_unit {
|
||||
lhs_unit
|
||||
} else {
|
||||
return Err(ParseError::Type(TypeError::InvalidBinaryOp(lhs_unit, op, rhs_unit)));
|
||||
}
|
||||
}
|
||||
BinaryOp::Pow => {
|
||||
return Err(ParseError::Type(TypeError::InvalidBinaryOp(lhs_unit, op, rhs_unit)));
|
||||
}
|
||||
},
|
||||
|
||||
(true, false) => match op {
|
||||
BinaryOp::Add | BinaryOp::Sub => return Err(ParseError::Type(TypeError::InvalidBinaryOp(lhs_unit, op, Unit::BASE_UNIT))),
|
||||
BinaryOp::Pow => {
|
||||
//TODO: improve error type
|
||||
//TODO: support 1 / int
|
||||
if let Ok(Value::Number(Number::Real(val))) = rhs.eval(&EvalContext::default()) {
|
||||
if (val - val as i32 as f64).abs() <= f64::EPSILON {
|
||||
Unit {
|
||||
length: lhs_unit.length * val as i32,
|
||||
mass: lhs_unit.mass * val as i32,
|
||||
time: lhs_unit.time * val as i32,
|
||||
}
|
||||
} else {
|
||||
return Err(ParseError::Type(TypeError::InvalidBinaryOp(lhs_unit, op, Unit::BASE_UNIT)));
|
||||
}
|
||||
} else {
|
||||
return Err(ParseError::Type(TypeError::InvalidBinaryOp(lhs_unit, op, Unit::BASE_UNIT)));
|
||||
}
|
||||
}
|
||||
_ => lhs_unit,
|
||||
},
|
||||
(false, true) => match op {
|
||||
BinaryOp::Add | BinaryOp::Sub | BinaryOp::Pow => return Err(ParseError::Type(TypeError::InvalidBinaryOp(Unit::BASE_UNIT, op, rhs_unit))),
|
||||
_ => rhs_unit,
|
||||
},
|
||||
(false, false) => Unit::BASE_UNIT,
|
||||
};
|
||||
|
||||
let node = Node::BinOp {
|
||||
lhs: Box::new(lhs),
|
||||
op,
|
||||
rhs: Box::new(rhs),
|
||||
};
|
||||
|
||||
Ok((node, NodeMetadata::new(unit)))
|
||||
})
|
||||
.parse(pairs)
|
||||
}
|
||||
|
||||
//TODO: set up Unit test for Units
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
macro_rules! test_parser {
|
||||
($($name:ident: $input:expr => $expected:expr),* $(,)?) => {
|
||||
$(
|
||||
#[test]
|
||||
fn $name() {
|
||||
let result = Node::from_str($input).unwrap();
|
||||
assert_eq!(result.0, $expected);
|
||||
}
|
||||
)*
|
||||
};
|
||||
}
|
||||
|
||||
test_parser! {
|
||||
test_parse_int_literal: "42" => Node::Lit(Literal::Float(42.0)),
|
||||
test_parse_float_literal: "3.14" => Node::Lit(Literal::Float(3.14)),
|
||||
test_parse_ident: "x" => Node::Var("x".to_string()),
|
||||
test_parse_unary_neg: "-42" => Node::UnaryOp {
|
||||
expr: Box::new(Node::Lit(Literal::Float(42.0))),
|
||||
op: UnaryOp::Neg,
|
||||
},
|
||||
test_parse_binary_add: "1 + 2" => Node::BinOp {
|
||||
lhs: Box::new(Node::Lit(Literal::Float(1.0))),
|
||||
op: BinaryOp::Add,
|
||||
rhs: Box::new(Node::Lit(Literal::Float(2.0))),
|
||||
},
|
||||
test_parse_binary_mul: "3 * 4" => Node::BinOp {
|
||||
lhs: Box::new(Node::Lit(Literal::Float(3.0))),
|
||||
op: BinaryOp::Mul,
|
||||
rhs: Box::new(Node::Lit(Literal::Float(4.0))),
|
||||
},
|
||||
test_parse_binary_pow: "2 ^ 3" => Node::BinOp {
|
||||
lhs: Box::new(Node::Lit(Literal::Float(2.0))),
|
||||
op: BinaryOp::Pow,
|
||||
rhs: Box::new(Node::Lit(Literal::Float(3.0))),
|
||||
},
|
||||
test_parse_unary_sqrt: "sqrt(16)" => Node::UnaryOp {
|
||||
expr: Box::new(Node::Lit(Literal::Float(16.0))),
|
||||
op: UnaryOp::Sqrt,
|
||||
},
|
||||
test_parse_sqr_ident: "sqr(16)" => Node::FnCall {
|
||||
name:"sqr".to_string(),
|
||||
expr: vec![Node::Lit(Literal::Float(16.0))]
|
||||
},
|
||||
|
||||
test_parse_complex_expr: "(1 + 2) 3 - 4 ^ 2" => Node::BinOp {
|
||||
lhs: Box::new(Node::BinOp {
|
||||
lhs: Box::new(Node::BinOp {
|
||||
lhs: Box::new(Node::Lit(Literal::Float(1.0))),
|
||||
op: BinaryOp::Add,
|
||||
rhs: Box::new(Node::Lit(Literal::Float(2.0))),
|
||||
}),
|
||||
op: BinaryOp::Mul,
|
||||
rhs: Box::new(Node::Lit(Literal::Float(3.0))),
|
||||
}),
|
||||
op: BinaryOp::Sub,
|
||||
rhs: Box::new(Node::BinOp {
|
||||
lhs: Box::new(Node::Lit(Literal::Float(4.0))),
|
||||
op: BinaryOp::Pow,
|
||||
rhs: Box::new(Node::Lit(Literal::Float(2.0))),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
128
libraries/math-parser/src/value.rs
Normal file
128
libraries/math-parser/src/value.rs
Normal file
|
@ -0,0 +1,128 @@
|
|||
use std::f64::consts::PI;
|
||||
|
||||
use num_complex::ComplexFloat;
|
||||
|
||||
use crate::ast::{BinaryOp, UnaryOp};
|
||||
|
||||
pub type Complex = num_complex::Complex<f64>;
|
||||
|
||||
#[derive(Debug, PartialEq, Clone, Copy)]
|
||||
pub enum Value {
|
||||
Number(Number),
|
||||
}
|
||||
|
||||
impl Value {
|
||||
pub fn from_f64(x: f64) -> Self {
|
||||
Self::Number(Number::Real(x))
|
||||
}
|
||||
|
||||
pub fn as_real(&self) -> Option<f64> {
|
||||
match self {
|
||||
Self::Number(Number::Real(val)) => Some(*val),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<f64> for Value {
|
||||
fn from(x: f64) -> Self {
|
||||
Self::from_f64(x)
|
||||
}
|
||||
}
|
||||
|
||||
impl core::fmt::Display for Value {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Value::Number(num) => num.fmt(f),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone, Copy)]
|
||||
pub enum Number {
|
||||
Real(f64),
|
||||
Complex(Complex),
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Number {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Number::Real(real) => real.fmt(f),
|
||||
Number::Complex(complex) => complex.fmt(f),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Number {
|
||||
pub fn binary_op(self, op: BinaryOp, other: Number) -> Number {
|
||||
match (self, other) {
|
||||
(Number::Real(lhs), Number::Real(rhs)) => {
|
||||
let result = match op {
|
||||
BinaryOp::Add => lhs + rhs,
|
||||
BinaryOp::Sub => lhs - rhs,
|
||||
BinaryOp::Mul => lhs * rhs,
|
||||
BinaryOp::Div => lhs / rhs,
|
||||
BinaryOp::Pow => lhs.powf(rhs),
|
||||
};
|
||||
Number::Real(result)
|
||||
}
|
||||
|
||||
(Number::Complex(lhs), Number::Complex(rhs)) => {
|
||||
let result = match op {
|
||||
BinaryOp::Add => lhs + rhs,
|
||||
BinaryOp::Sub => lhs - rhs,
|
||||
BinaryOp::Mul => lhs * rhs,
|
||||
BinaryOp::Div => lhs / rhs,
|
||||
BinaryOp::Pow => lhs.powc(rhs),
|
||||
};
|
||||
Number::Complex(result)
|
||||
}
|
||||
|
||||
(Number::Real(lhs), Number::Complex(rhs)) => {
|
||||
let lhs_complex = Complex::new(lhs, 0.0);
|
||||
let result = match op {
|
||||
BinaryOp::Add => lhs_complex + rhs,
|
||||
BinaryOp::Sub => lhs_complex - rhs,
|
||||
BinaryOp::Mul => lhs_complex * rhs,
|
||||
BinaryOp::Div => lhs_complex / rhs,
|
||||
BinaryOp::Pow => lhs_complex.powc(rhs),
|
||||
};
|
||||
Number::Complex(result)
|
||||
}
|
||||
|
||||
(Number::Complex(lhs), Number::Real(rhs)) => {
|
||||
let rhs_complex = Complex::new(rhs, 0.0);
|
||||
let result = match op {
|
||||
BinaryOp::Add => lhs + rhs_complex,
|
||||
BinaryOp::Sub => lhs - rhs_complex,
|
||||
BinaryOp::Mul => lhs * rhs_complex,
|
||||
BinaryOp::Div => lhs / rhs_complex,
|
||||
BinaryOp::Pow => lhs.powf(rhs),
|
||||
};
|
||||
Number::Complex(result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn unary_op(self, op: UnaryOp) -> Number {
|
||||
match self {
|
||||
Number::Real(real) => match op {
|
||||
UnaryOp::Neg => Number::Real(-real),
|
||||
UnaryOp::Sqrt => Number::Real(real.sqrt()),
|
||||
|
||||
UnaryOp::Fac => todo!("Implement factorial"),
|
||||
},
|
||||
|
||||
Number::Complex(complex) => match op {
|
||||
UnaryOp::Neg => Number::Complex(-complex),
|
||||
UnaryOp::Sqrt => Number::Complex(complex.sqrt()),
|
||||
|
||||
UnaryOp::Fac => todo!("Implement factorial"),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_f64(x: f64) -> Self {
|
||||
Self::Real(x)
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue