Simplify node trait definition (#1146)

* Simplify node trait
This commit is contained in:
Dennis Kobert 2023-04-17 23:42:22 +02:00 committed by Keavon Chambers
parent 1d6c4f13dd
commit 76c754d38a
11 changed files with 47 additions and 105 deletions

View file

@ -5,7 +5,7 @@ pub struct FnNode<T: Fn(I) -> O, I, O>(T, PhantomData<(I, O)>);
impl<'i, T: Fn(I) -> O + 'i, O: 'i, I: 'i> Node<'i, I> for FnNode<T, I, O> {
type Output = O;
fn eval<'s: 'i>(&'s self, input: I) -> Self::Output {
fn eval(&'i self, input: I) -> Self::Output {
self.0(input)
}
}
@ -19,11 +19,11 @@ impl<T: Fn(I) -> O, I, O> FnNode<T, I, O> {
pub struct FnNodeWithState<'i, T: Fn(I, &'i State) -> O, I, O, State: 'i>(T, State, PhantomData<(&'i O, I)>);
impl<'i, I: 'i, O: 'i, State, T: Fn(I, &'i State) -> O + 'i> Node<'i, I> for FnNodeWithState<'i, T, I, O, State> {
type Output = O;
fn eval<'s: 'i>(&'s self, input: I) -> Self::Output {
fn eval(&'i self, input: I) -> Self::Output {
(self.0)(input, &self.1)
}
}
impl<'i, 's: 'i, I, O, State, T: Fn(I, &'i State) -> O> FnNodeWithState<'i, T, I, O, State> {
impl<'i, I, O, State, T: Fn(I, &'i State) -> O> FnNodeWithState<'i, T, I, O, State> {
pub fn new(f: T, state: State) -> Self {
FnNodeWithState(f, state, PhantomData)
}

View file

@ -33,7 +33,7 @@ pub use raster::Color;
// pub trait Node: for<'n> NodeIO<'n> {
pub trait Node<'i, Input: 'i>: 'i {
type Output: 'i;
fn eval<'s: 'i>(&'s self, input: Input) -> Self::Output;
fn eval(&'i self, input: Input) -> Self::Output;
fn reset(self: Pin<&mut Self>) {}
}
@ -79,14 +79,14 @@ where
/*impl<'i, I: 'i, O: 'i> Node<'i, I> for &'i dyn for<'n> Node<'n, I, Output = O> {
type Output = O;
fn eval<'s: 'i>(&'s self, input: I) -> Self::Output {
fn eval(&'i self, input: I) -> Self::Output {
(**self).eval(input)
}
}*/
impl<'i, 'n: 'i, I: 'i, O: 'i> Node<'i, I> for &'n dyn for<'a> Node<'a, I, Output = O> {
impl<'i, I: 'i, O: 'i> Node<'i, I> for &'i dyn for<'a> Node<'a, I, Output = O> {
type Output = O;
fn eval<'s: 'i>(&'s self, input: I) -> Self::Output {
fn eval(&'i self, input: I) -> Self::Output {
(**self).eval(input)
}
}
@ -97,7 +97,7 @@ use dyn_any::StaticType;
impl<'i, I: 'i, O: 'i> Node<'i, I> for Pin<Box<dyn for<'a> Node<'a, I, Output = O> + 'i>> {
type Output = O;
fn eval<'s: 'i>(&'s self, input: I) -> Self::Output {
fn eval(&'i self, input: I) -> Self::Output {
(**self).eval(input)
}
}

View file

@ -8,7 +8,7 @@ pub struct AddNode;
impl<'i, L: Add<R, Output = O> + 'i, R: 'i, O: 'i> Node<'i, (L, R)> for AddNode {
type Output = <L as Add<R>>::Output;
fn eval<'s: 'i>(&'s self, input: (L, R)) -> Self::Output {
fn eval(&'i self, input: (L, R)) -> Self::Output {
input.0 + input.1
}
}
@ -30,58 +30,6 @@ where
first + second
}
/*
#[cfg(feature = "std")]
pub mod dynamic {
use super::*;
// Unfortunatly we can't impl the AddNode as we get
// `upstream crates may add a new impl of trait `core::ops::Add` for type `alloc::boxed::Box<(dyn dyn_any::DynAny<'_> + 'static)>` in future versions`
pub struct DynamicAddNode;
// Alias for a dynamic type
pub type Dynamic<'a> = alloc::boxed::Box<dyn dyn_any::DynAny<'a> + 'a>;
/// Resolves the dynamic types for a dynamic node.
///
/// Macro uses format `BaseNode => (arg1: u32) (arg1: i32)`
macro_rules! resolve_dynamic_types {
($node:ident => $(($($arg:ident : $t:ty),*))*) => {
$(
// Check for each possible set of arguments if their types match the arguments given
if $(core::any::TypeId::of::<$t>() == $arg.type_id())&&* {
// Cast the arguments and then call the inner node
alloc::boxed::Box::new($node.eval(($(*dyn_any::downcast::<$t>($arg).unwrap()),*)) ) as Dynamic
}
)else*
else {
panic!("Unhandled type"); // TODO: Exit neatly (although this should probably not happen)
}
};
}
impl<'i> Node<(Dynamic<'i>, Dynamic<'i>)> for DynamicAddNode {
type Output = Dynamic<'i>;
fn eval<'s: 'i>(self, (left, right): (Dynamic, Dynamic)) -> Self::Output {
resolve_dynamic_types! { AddNode =>
(left: usize, right: usize)
(left: u8, right: u8)
(left: u16, right: u16)
(left: u32, right: u32)
(left: u64, right: u64)
(left: u128, right: u128)
(left: isize, right: isize)
(left: i8, right: i8)
(left: i16, right: i16)
(left: i32, right: i32)
(left: i64, right: i64)
(left: i128, right: i128)
(left: f32, right: f32)
(left: f64, right: f64) }
}
}
}*/
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub struct SomeNode;
#[node_macro::node_fn(SomeNode)]
@ -93,7 +41,7 @@ fn some<T>(input: T) -> Option<T> {
pub struct CloneNode<O>(PhantomData<O>);
impl<'i, 'n: 'i, O: Clone + 'i> Node<'i, &'n O> for CloneNode<O> {
type Output = O;
fn eval<'s: 'i>(&'s self, input: &'i O) -> Self::Output {
fn eval(&'i self, input: &'i O) -> Self::Output {
input.clone()
}
}
@ -107,7 +55,7 @@ impl<O> CloneNode<O> {
pub struct FstNode;
impl<'i, L: 'i, R: 'i> Node<'i, (L, R)> for FstNode {
type Output = L;
fn eval<'s: 'i>(&'s self, input: (L, R)) -> Self::Output {
fn eval(&'i self, input: (L, R)) -> Self::Output {
input.0
}
}
@ -122,7 +70,7 @@ impl FstNode {
pub struct SndNode;
impl<'i, L: 'i, R: 'i> Node<'i, (L, R)> for SndNode {
type Output = R;
fn eval<'s: 'i>(&'s self, input: (L, R)) -> Self::Output {
fn eval(&'i self, input: (L, R)) -> Self::Output {
input.1
}
}
@ -137,7 +85,7 @@ impl SndNode {
pub struct SwapNode;
impl<'i, L: 'i, R: 'i> Node<'i, (L, R)> for SwapNode {
type Output = (R, L);
fn eval<'s: 'i>(&'s self, input: (L, R)) -> Self::Output {
fn eval(&'i self, input: (L, R)) -> Self::Output {
(input.1, input.0)
}
}
@ -152,7 +100,7 @@ impl SwapNode {
pub struct DupNode;
impl<'i, O: Clone + 'i> Node<'i, O> for DupNode {
type Output = (O, O);
fn eval<'s: 'i>(&'s self, input: O) -> Self::Output {
fn eval(&'i self, input: O) -> Self::Output {
(input.clone(), input)
}
}
@ -167,7 +115,7 @@ impl DupNode {
pub struct IdNode;
impl<'i, O: 'i> Node<'i, O> for IdNode {
type Output = O;
fn eval<'s: 'i>(&'s self, input: O) -> Self::Output {
fn eval(&'i self, input: O) -> Self::Output {
input
}
}
@ -186,7 +134,7 @@ where
N: for<'n> Node<'n, I, Output = O>,
{
type Output = O;
fn eval<'s: 'i>(&'s self, input: I) -> Self::Output {
fn eval(&'i self, input: I) -> Self::Output {
self.0.eval(input)
}
}

View file

@ -200,7 +200,7 @@ pub struct MapNode<MapFn> {
}
#[node_macro::node_fn(MapNode)]
fn map_node<_Iter: Iterator, MapFnNode>(input: _Iter, map_fn: &'any_input MapFnNode) -> MapFnIterator<'input, 'input, _Iter, MapFnNode>
fn map_node<_Iter: Iterator, MapFnNode>(input: _Iter, map_fn: &'any_input MapFnNode) -> MapFnIterator<'input, _Iter, MapFnNode>
where
MapFnNode: for<'any_input> Node<'any_input, _Iter::Item>,
{
@ -208,40 +208,34 @@ where
}
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct MapFnIterator<'i, 's, Iter, MapFn> {
pub struct MapFnIterator<'i, Iter, MapFn> {
iter: Iter,
map_fn: &'s MapFn,
_phantom: core::marker::PhantomData<&'i &'s ()>,
map_fn: &'i MapFn,
}
impl<'i, 's: 'i, Iter: Debug, MapFn> Debug for MapFnIterator<'i, 's, Iter, MapFn> {
impl<'i, Iter: Debug, MapFn> Debug for MapFnIterator<'i, Iter, MapFn> {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.debug_struct("MapFnIterator").field("iter", &self.iter).field("map_fn", &"MapFn").finish()
}
}
impl<'i, 's: 'i, Iter: Clone, MapFn> Clone for MapFnIterator<'i, 's, Iter, MapFn> {
impl<'i, Iter: Clone, MapFn> Clone for MapFnIterator<'i, Iter, MapFn> {
fn clone(&self) -> Self {
Self {
iter: self.iter.clone(),
map_fn: self.map_fn,
_phantom: core::marker::PhantomData,
}
}
}
impl<'i, 's: 'i, Iter: Copy, MapFn> Copy for MapFnIterator<'i, 's, Iter, MapFn> {}
impl<'i, Iter: Copy, MapFn> Copy for MapFnIterator<'i, Iter, MapFn> {}
impl<'i, 's: 'i, Iter, MapFn> MapFnIterator<'i, 's, Iter, MapFn> {
pub fn new(iter: Iter, map_fn: &'s MapFn) -> Self {
Self {
iter,
map_fn,
_phantom: core::marker::PhantomData,
}
impl<'i, Iter, MapFn> MapFnIterator<'i, Iter, MapFn> {
pub fn new(iter: Iter, map_fn: &'i MapFn) -> Self {
Self { iter, map_fn }
}
}
impl<'i, 's: 'i, I: Iterator + 's, F> Iterator for MapFnIterator<'i, 's, I, F>
impl<'i, I: Iterator + 'i, F> Iterator for MapFnIterator<'i, I, F>
where
F: Node<'i, I::Item> + 'i,
Self: 'i,
@ -318,7 +312,7 @@ where
{
type Output = ImageWindowIterator<'input, P>;
#[inline]
fn eval<'node: 'input>(&'node self, input: u32) -> Self::Output {
fn eval(&'input self, input: u32) -> Self::Output {
let radius = self.radius.eval(());
let image = self.image.eval(());
{

View file

@ -10,7 +10,7 @@ pub struct BrightnessContrastLegacyMapperNode {
impl<'i> Node<'i, Color> for BrightnessContrastLegacyMapperNode {
type Output = Color;
fn eval<'s: 'i>(&'s self, color: Color) -> Color {
fn eval(&'i self, color: Color) -> Color {
let color = color.to_gamma_srgb();
let color = color.map_rgb(|c| (c + c * self.contrast + self.combined).clamp(0., 1.));
@ -46,7 +46,7 @@ pub struct BrightnessContrastMapperNode {
impl<'i> Node<'i, Color> for BrightnessContrastMapperNode {
type Output = Color;
fn eval<'s: 'i>(&'s self, color: Color) -> Color {
fn eval(&'i self, color: Color) -> Color {
let color = color.to_gamma_srgb();
let color = color.map_rgb(|c| {

View file

@ -14,7 +14,7 @@ where
Second: for<'a> Node<'a, <First as Node<'a, Input>>::Output> + 'i,
{
type Output = <Second as Node<'i, <First as Node<'i, Input>>::Output>>::Output;
fn eval<'s: 'i>(&'s self, input: Input) -> Self::Output {
fn eval(&'i self, input: Input) -> Self::Output {
let arg = self.first.eval(input);
self.second.eval(arg)
}
@ -64,7 +64,7 @@ where
Root: Node<'i, I>,
{
type Output = (Input, Root::Output);
fn eval<'s: 'i>(&'s self, input: Input) -> Self::Output {
fn eval(&'i self, input: Input) -> Self::Output {
let arg = self.0.eval(I::from(()));
(input, arg)
}

View file

@ -8,7 +8,7 @@ pub struct IntNode<const N: u32>;
impl<'i, const N: u32> Node<'i, ()> for IntNode<N> {
type Output = u32;
fn eval<'s: 'i>(&'s self, _input: ()) -> Self::Output {
fn eval(&'i self, _input: ()) -> Self::Output {
N
}
}
@ -18,7 +18,7 @@ pub struct ValueNode<T>(pub T);
impl<'i, T: 'i> Node<'i, ()> for ValueNode<T> {
type Output = &'i T;
fn eval<'s: 'i>(&'s self, _input: ()) -> Self::Output {
fn eval(&'i self, _input: ()) -> Self::Output {
&self.0
}
}
@ -46,7 +46,7 @@ pub struct ClonedNode<T: Clone>(pub T);
impl<'i, T: Clone + 'i> Node<'i, ()> for ClonedNode<T> {
type Output = T;
fn eval<'s: 'i>(&'s self, _input: ()) -> Self::Output {
fn eval(&'i self, _input: ()) -> Self::Output {
self.0.clone()
}
}
@ -69,7 +69,7 @@ pub struct DefaultNode<T>(PhantomData<T>);
impl<'i, T: Default + 'i> Node<'i, ()> for DefaultNode<T> {
type Output = T;
fn eval<'s: 'i>(&self, _input: ()) -> Self::Output {
fn eval(&'i self, _input: ()) -> Self::Output {
T::default()
}
}
@ -87,7 +87,7 @@ pub struct ForgetNode;
impl<'i, T: 'i> Node<'i, T> for ForgetNode {
type Output = ();
fn eval<'s: 'i>(&self, _input: T) -> Self::Output {}
fn eval(&'i self, _input: T) -> Self::Output {}
}
impl ForgetNode {

View file

@ -213,7 +213,7 @@ pub struct UpcastNode {
impl<'input> Node<'input, Box<dyn DynAny<'input> + 'input>> for UpcastNode {
type Output = Box<dyn DynAny<'input> + 'input>;
fn eval<'s: 'input>(&'s self, _: Box<dyn DynAny<'input> + 'input>) -> Self::Output {
fn eval(&'input self, _: Box<dyn DynAny<'input> + 'input>) -> Self::Output {
self.value.clone().to_any()
}
}

View file

@ -27,7 +27,7 @@ where
N: for<'any_input> Node<'any_input, _I, Output = &'any_input _O>,
{
type Output = Any<'input>;
fn eval<'node: 'input>(&'node self, input: Any<'input>) -> Self::Output {
fn eval(&'input self, input: Any<'input>) -> Self::Output {
{
let node_name = core::any::type_name::<N>();
let input: Box<_I> = dyn_any::downcast(input).unwrap_or_else(|e| panic!("DynAnyRefNode Input, {e} in:\n{node_name}"));
@ -54,7 +54,7 @@ where
N: for<'any_input> Node<'any_input, &'any_input _I, Output = _O>,
{
type Output = Any<'input>;
fn eval<'node: 'input>(&'node self, input: Any<'input>) -> Self::Output {
fn eval(&'input self, input: Any<'input>) -> Self::Output {
{
let node_name = core::any::type_name::<N>();
let input: Box<&_I> = dyn_any::downcast(input).unwrap_or_else(|e| panic!("DynAnyInRefNode Input, {e} in:\n{node_name}"));
@ -113,7 +113,7 @@ pub struct DowncastBothNode<'a, I, O> {
impl<'n: 'input, 'input, O: 'input + StaticType, I: 'input + StaticType> Node<'input, I> for DowncastBothNode<'n, I, O> {
type Output = O;
#[inline]
fn eval<'node: 'input>(&'node self, input: I) -> Self::Output {
fn eval(&'input self, input: I) -> Self::Output {
{
let input = Box::new(input);
let out = dyn_any::downcast(self.node.eval(input)).unwrap_or_else(|e| panic!("DowncastBothNode Input {e}"));
@ -140,7 +140,7 @@ pub struct DowncastBothRefNode<'a, I, O> {
impl<'n: 'input, 'input, O: 'input + StaticType, I: 'input + StaticType> Node<'input, I> for DowncastBothRefNode<'n, I, O> {
type Output = &'input O;
#[inline]
fn eval<'node: 'input>(&'node self, input: I) -> Self::Output {
fn eval(&'input self, input: I) -> Self::Output {
{
let input = Box::new(input);
let out: Box<&_> = dyn_any::downcast::<&O>(self.node.eval(input)).unwrap_or_else(|e| panic!("DowncastBothRefNode Input {e}"));
@ -161,7 +161,7 @@ pub struct ComposeTypeErased<'a> {
impl<'i, 'a: 'i> Node<'i, Any<'i>> for ComposeTypeErased<'a> {
type Output = Any<'i>;
fn eval<'s: 'i>(&'s self, input: Any<'i>) -> Self::Output {
fn eval(&'i self, input: Any<'i>) -> Self::Output {
let arg = self.first.eval(input);
self.second.eval(arg)
}

View file

@ -19,7 +19,7 @@ where
CachedNode: for<'any_input> Node<'any_input, I, Output = T>,
{
type Output = &'i T;
fn eval<'s: 'i>(&'s self, input: I) -> Self::Output {
fn eval(&'i self, input: I) -> Self::Output {
let mut hasher = Xxh3::new();
input.hash(&mut hasher);
let hash = hasher.finish();
@ -63,7 +63,7 @@ pub struct LetNode<T> {
}
impl<'i, T: 'i + Hash> Node<'i, Option<T>> for LetNode<T> {
type Output = &'i T;
fn eval<'s: 'i>(&'s self, input: Option<T>) -> Self::Output {
fn eval(&'i self, input: Option<T>) -> Self::Output {
match input {
Some(input) => {
let mut hasher = Xxh3::new();
@ -108,7 +108,7 @@ where
Input: Node<'i, ()>,
{
type Output = <Input>::Output;
fn eval<'s: 'i>(&'s self, _: &'i T) -> Self::Output {
fn eval(&'i self, _: &'i T) -> Self::Output {
self.input.eval(())
}
}
@ -131,7 +131,7 @@ where
Let: for<'a> Node<'a, Option<T>, Output = &'a T>,
{
type Output = &'i T;
fn eval<'s: 'i>(&'s self, _: ()) -> Self::Output {
fn eval(&'i self, _: ()) -> Self::Output {
self.let_node.eval(None)
}
}

View file

@ -146,7 +146,7 @@ pub fn node_fn(attr: TokenStream, item: TokenStream) -> TokenStream {
{
type Output = #output;
#[inline]
fn eval<'node: 'input>(&'node self, #primary_input_mutability #primary_input_ident: #primary_input_ty) -> Self::Output {
fn eval(&'input self, #primary_input_mutability #primary_input_ident: #primary_input_ty) -> Self::Output {
#(
let #parameter_mutability #parameter_idents = self.#parameter_idents.eval(());
)*