Transition to a trait-based node graph

This commit is contained in:
Dennis 2022-03-27 17:48:24 +02:00 committed by Keavon Chambers
parent a807a54c80
commit ab727de684
4 changed files with 162 additions and 131 deletions

18
node-graph/Cargo.lock generated
View file

@ -218,15 +218,6 @@ version = "0.4.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7ab85b9b05e3978cc9a9cf8fea7f01b494e1a09ed3037e16ba39edc7a29eb61a"
[[package]]
name = "graph-ite"
version = "0.1.0"
dependencies = [
"graph-proc-macros",
"ra_ap_ide",
"ra_ap_ide_db",
]
[[package]]
name = "graph-proc-macros"
version = "0.1.0"
@ -359,6 +350,15 @@ dependencies = [
"winapi",
]
[[package]]
name = "nodegraph-experiments"
version = "0.1.0"
dependencies = [
"graph-proc-macros",
"ra_ap_ide",
"ra_ap_ide_db",
]
[[package]]
name = "num_cpus"
version = "1.13.0"

View file

@ -1,7 +1,7 @@
[package]
name = "graph-ite"
name = "nodegraph-experiments"
version = "0.1.0"
edition = "2018"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

View file

@ -1,7 +1,7 @@
use proc_macro::TokenStream;
use quote::{quote, ToTokens};
use syn::punctuated::Punctuated;
use syn::{parse_macro_input, FnArg, ItemFn, Pat, Type};
use syn::{parse_macro_input, FnArg, ItemFn, Pat, Type};
fn extract_type(a: FnArg) -> Box<Type> {
match a {
@ -39,34 +39,51 @@ fn generate_to_string(parsed: ItemFn, string: String) -> TokenStream {
let whole_function = parsed.clone();
//let fn_body = parsed.block; // function body
let sig = parsed.sig; // function signature
//let vis = parsed.vis; // visibility, pub or not
//let vis = parsed.vis; // visibility, pub or not
let generics = sig.generics;
let fn_args = sig.inputs; // comma separated args
let fn_return_type = sig.output; // return type
let fn_name = sig.ident; // function name/identifier
let idents = extract_arg_idents(fn_args.clone());
let types = extract_arg_types(fn_args);
let types = types.iter().map(|t| t.to_token_stream()).collect::<Vec<_>>();
let idents = idents.iter().map(|t| t.to_token_stream()).collect::<Vec<_>>();
let types = types
.iter()
.map(|t| t.to_token_stream())
.collect::<Vec<_>>();
let idents = idents
.iter()
.map(|t| t.to_token_stream())
.collect::<Vec<_>>();
let node_fn_name = syn::Ident::new(&(fn_name.to_string() + "_node"), proc_macro2::Span::call_site()); // function name/identifier
let return_type_string = fn_return_type.to_token_stream().to_string().replace("->","");
let arg_type_string = types.iter().map(|t|t.to_string()).collect::<Vec<_>>().join(", ");
let node_fn_name = syn::Ident::new(
&(fn_name.to_string() + "_node"),
proc_macro2::Span::call_site(),
); // function name/identifier
let return_type_string = fn_return_type
.to_token_stream()
.to_string()
.replace("->", "");
let arg_type_string = types
.iter()
.map(|t| t.to_string())
.collect::<Vec<_>>()
.join(", ");
let error = format!("called {} with the wrong type", fn_name.to_string());
let x = quote! {
//#whole_function
fn #node_fn_name #generics() -> Node {
fn #node_fn_name #generics() -> Node<'static> {
Node { func: Box::new(move |x| {
let args = x.downcast::<(#(#types,)*)>().expect(#error);
let (#(#idents,)*) = *args;
#whole_function
let args = x.downcast::<(#(#types,)*)>().expect(#error);
let (#(#idents,)*) = *args;
#whole_function
Box::new(#fn_name(#(#idents,)*))
}),
code: #string.to_string(),
return_type: #return_type_string.trim().to_string(),
args: format!("({})",#arg_type_string.trim()),
Box::new(#fn_name(#(#idents,)*))
}),
code: #string.to_string(),
return_type: #return_type_string.trim().to_string(),
args: format!("({})",#arg_type_string.trim()),
position: (0., 0.),
}
}

View file

@ -1,117 +1,131 @@
use std::any::Any;
use std::{any::Any, iter::Sum, ops::Add};
type Function = Box<dyn Fn(Box<dyn Any>) -> Box<dyn Any>>;
struct Node {
func: Function,
code: String,
return_type: String,
args: String,
}
impl Node {
fn eval<T: 'static, U: 'static>(&self, t: T) -> U {
*(self.func)(Box::new(t)).downcast::<U>().unwrap()
}
#[allow(unused)]
fn id(self) -> Self {
self
}
}
impl std::ops::Mul<Self> for Node {
type Output = Self;
fn mul(self, other: Self) -> Self {
node_compose(self, other)
}
}
pub fn compose<F: 'static, G: 'static, Fv, Gv, V>(g: G, f: F) -> Box<dyn Fn(Fv) -> V>
pub struct InsertAfterNth<A>
where
F: Fn(Fv) -> Gv,
G: Fn(Gv) -> V,
A: Iterator,
{
Box::new(move |x| g(f(x)))
n: usize,
iter: A,
value: Option<A::Item>,
}
fn node_compose(g: Node, f: Node) -> Node {
#[rustfmt::skip]
let Node { func: ff, code: fc, args: fa, return_type: fr} = f;
#[rustfmt::skip]
let Node { func, code, args, return_type } = g;
assert_eq!(args, fr);
Node {
func: Box::new(move |x| func(ff(x))),
code: fc + code.as_str(), // temporary TODO: replace
return_type,
args: fa,
}
}
#[graph_proc_macros::to_node]
fn id<T:'static>(t: T) -> T {
t
}
impl<A> Iterator for InsertAfterNth<A>
where
A: Iterator,
{
type Item = A::Item;
#[graph_proc_macros::to_node]
fn gen_int() -> (u32, u32) {
(42, 43)
}
#[graph_proc_macros::to_node]
fn format_int(x: u32, y: u32) -> String {
x.to_string() + &y.to_string()
}
#[graph_proc_macros::to_node]
fn curry_first_u32(x: u32, node: Node) -> Node {
assert_eq!(node.args[1..].split(",").next(), Some("u32"));
curry_first_arg_node::<u32>().eval((x, node))
}
#[graph_proc_macros::to_node]
fn curry_first_arg<T: 'static + Clone>(x: T, node: Node) -> Node {
node_after_fn_node().eval::<(Node, Function), Node>((
node,
Box::new(move |y: Box<dyn Any>| {
Box::new((x.clone(), *y.downcast::<T>().unwrap())) as Box<dyn Any>
}) ,
))
}
#[graph_proc_macros::to_node]
fn compose_node(g: Node, f: Node) -> Node {
node_compose(g, f)
}
#[graph_proc_macros::to_node]
fn node_after_fn(g: Node, f: Box<dyn Fn(Box<dyn Any>) -> Box<dyn Any>>) -> Node {
let Node {
func, return_type, ..
} = g;
Node {
func: compose(func, f),
code: "unimplemented".to_string(),
return_type,
args: "".to_string(),
fn next(&mut self) -> Option<Self::Item> {
match self.n {
1.. => {
self.n -= 1;
self.iter.next()
}
0 if self.value.is_some() => self.value.take(),
_ => self.iter.next(),
}
}
}
#[graph_proc_macros::to_node]
fn node_from_fn(f: Box<dyn Fn(Box<dyn Any>) -> Box<dyn Any>>) -> Node {
node_after_fn_node().eval((id_node::<Box<dyn Any>>, f))
pub fn insert_after_nth<A>(n: usize, iter: A, value: A::Item) -> InsertAfterNth<A>
where
A: Iterator,
{
InsertAfterNth {
n,
iter,
value: Some(value),
}
}
trait Node<O> {
fn eval<'a>(&'a self, input: impl Iterator<Item = &'a dyn Any>) -> O;
// fn source code
// positon
}
struct IntNode;
impl Node<u32> for IntNode {
fn eval<'a>(&'a self, _input: impl Iterator<Item = &'a dyn Any>) -> u32 {
42
}
}
struct AddNode;
impl<T: Sum + 'static + Copy> Node<T> for AddNode {
fn eval<'a>(&'a self, input: impl Iterator<Item = &'a dyn Any>) -> T {
input
.take(2)
.map(|x| *(x.downcast_ref::<T>().unwrap()))
.sum::<T>()
}
}
struct CurryNthArgNode<'a, T: Node<O>, A, O, const N: usize> {
node: &'a T,
arg: A,
_phantom_data: std::marker::PhantomData<O>,
}
impl<'a, T: Node<O>, A: 'static, O, const N: usize> Node<O> for CurryNthArgNode<'a, T, A, O, N> {
fn eval<'b>(&'b self, input: impl Iterator<Item = &'b dyn Any>) -> O {
self.node
.eval(insert_after_nth(N, input, &self.arg as &dyn Any))
}
}
impl<'a, T: Node<O>, A: 'static, O, const N: usize> CurryNthArgNode<'a, T, A, O, N> {
fn new(node: &'a T, arg: A) -> Self {
CurryNthArgNode::<'a, T, A, O, N> {
node,
arg,
_phantom_data: std::marker::PhantomData::default(),
}
}
}
struct ComposeNode<'a, L, R, B>
where
L: Node<B>,
{
first: &'a L,
second: &'a R,
_phantom_data: std::marker::PhantomData<B>,
}
impl<'a, B: 'static, L, R, O> Node<O> for ComposeNode<'a, L, R, B>
where
L: Node<B>,
R: Node<O>,
{
fn eval<'b>(&'b self, input: impl Iterator<Item = &'b dyn Any>) -> O {
let curry = CurryNthArgNode::<'a, R, B, O, 0> {
node: self.second,
arg: self.first.eval(input),
_phantom_data: std::marker::PhantomData::default(),
};
let result: O = curry.eval([].into_iter());
result
}
}
impl<'a, L, R, B: 'static> ComposeNode<'a, L, R, B>
where
L: Node<B>,
{
fn new(first: &'a L, second: &'a R) -> Self {
ComposeNode::<'a, L, R, B> {
first,
second,
_phantom_data: std::marker::PhantomData::default(),
}
}
}
fn main() {
println!("{:?}",(format_int_node() * gen_int_node()).eval::<_, String>(()));
println!(
"{:?}",
curry_first_u32_node()
.eval::<(u32, Node), Node>((3, format_int_node()))
.eval::<u32, String>(43)
);
println!(
"{:?}",
curry_first_arg_node::<u32>()
.eval::<(u32, Node), Node>((3, format_int_node()))
.eval::<u32, String>(43)
);
let int = IntNode;
let curry: CurryNthArgNode<_, u32, u32, 0> =
CurryNthArgNode::new(&AddNode, int.eval(std::iter::empty()));
let composition = ComposeNode::new(&curry, &curry);
let curry: CurryNthArgNode<_, u32, _, 0> = CurryNthArgNode::new(&composition, 10);
println!("{}", curry.eval(std::iter::empty()))
}