Add type checking for parameter inputs (#1045)

This commit is contained in:
Dennis Kobert 2023-02-20 01:37:13 +01:00 committed by Keavon Chambers
parent 98f172414a
commit a993938d80
4 changed files with 41 additions and 33 deletions

View file

@ -51,7 +51,7 @@ where
core::any::type_name::<Self::Output>()
}
#[cfg(feature = "alloc")]
fn to_node_io(&self, parameters: Vec<Type>) -> NodeIOTypes {
fn to_node_io(&self, parameters: Vec<(Type, Type)>) -> NodeIOTypes {
NodeIOTypes {
input: concrete!(<Input as StaticType>::Static),
output: concrete!(<Self::Output as StaticType>::Static),

View file

@ -9,11 +9,11 @@ pub use std::borrow::Cow;
pub struct NodeIOTypes {
pub input: Type,
pub output: Type,
pub parameters: Vec<Type>,
pub parameters: Vec<(Type, Type)>,
}
impl NodeIOTypes {
pub fn new(input: Type, output: Type, parameters: Vec<Type>) -> Self {
pub fn new(input: Type, output: Type, parameters: Vec<(Type, Type)>) -> Self {
Self { input, output, parameters }
}
}

View file

@ -362,9 +362,9 @@ impl TypingContext {
self.inferred
.get(id)
.ok_or(format!("Inferring type of {node_id} depends on {id} which is not present in the typing context"))
.map(|node| node.output.clone())
.map(|node| (node.input.clone(), node.output.clone()))
})
.collect::<Result<Vec<Type>, String>>()?,
.collect::<Result<Vec<(Type, Type)>, String>>()?,
};
// Get the node input type from the proto node declaration
@ -384,7 +384,7 @@ impl TypingContext {
if matches!(input, Type::Generic(_)) {
return Err(format!("Generic types are not supported as inputs yet {:?} occured in {:?}", &input, node.identifier));
}
if parameters.iter().any(|p| matches!(p, Type::Generic(_))) {
if parameters.iter().any(|p| matches!(p.1, Type::Generic(_))) {
return Err(format!("Generic types are not supported in parameters: {:?} occured in {:?}", parameters, node.identifier));
}
let covariant = |output, input| match (&output, &input) {
@ -397,7 +397,13 @@ impl TypingContext {
// List of all implementations that match the input and parameter types
let valid_output_types = impls
.keys()
.filter(|node_io| covariant(input.clone(), node_io.input.clone()) && parameters.iter().zip(node_io.parameters.iter()).all(|(p1, p2)| covariant(p1.clone(), p2.clone())))
.filter(|node_io| {
covariant(input.clone(), node_io.input.clone())
&& parameters
.iter()
.zip(node_io.parameters.iter())
.all(|(p1, p2)| covariant(p1.0.clone(), p2.0.clone()) && covariant(p1.1.clone(), p2.1.clone()))
})
.collect::<Vec<_>>();
// Attempt to substitute generic types with concrete types and save the list of results
@ -445,7 +451,7 @@ impl TypingContext {
/// Returns a list of all generic types used in the node
fn collect_generics(types: &NodeIOTypes) -> Vec<Cow<'static, str>> {
let inputs = [&types.input].into_iter().chain(types.parameters.iter());
let inputs = [&types.input].into_iter().chain(types.parameters.iter().map(|(_, x)| x));
let mut generics = inputs
.filter_map(|t| match t {
Type::Generic(out) => Some(out.clone()),
@ -460,8 +466,10 @@ fn collect_generics(types: &NodeIOTypes) -> Vec<Cow<'static, str>> {
}
/// Checks if a generic type can be substituted with a concrete type and returns the concrete type
fn check_generic(types: &NodeIOTypes, input: &Type, parameters: &[Type], generic: &str) -> Result<Type, String> {
let inputs = [(&types.input, input)].into_iter().chain(types.parameters.iter().zip(parameters.iter()));
fn check_generic(types: &NodeIOTypes, input: &Type, parameters: &[(Type, Type)], generic: &str) -> Result<Type, String> {
let inputs = [(&types.input, input)]
.into_iter()
.chain(types.parameters.iter().map(|(_, x)| x).zip(parameters.iter().map(|(_, x)| x)));
let mut concrete_inputs = inputs.filter(|(ni, _)| matches!(ni, Type::Generic(input) if generic == input));
let (_, out_ty) = concrete_inputs
.next()

View file

@ -45,7 +45,7 @@ macro_rules! register_node {
let node = <$path>::new($(
graphene_std::any::input_node::<$type>(_node)
),*);
let params = vec![$(concrete!($type)),*];
let params = vec![$((concrete!(()), concrete!($type))),*];
let mut node_io = <$path as NodeIO<'_, $input>>::to_node_io(&node, params);
node_io.input = concrete!(<$input as StaticType>::Static);
node_io
@ -72,7 +72,7 @@ macro_rules! raster_node {
Box::pin(any)
},
{
let params = vec![$(concrete!($type)),*];
let params = vec![$((concrete!(()), concrete!($type))),*];
NodeIOTypes::new(concrete!(Image), concrete!(Image), params)
},
)
@ -110,7 +110,7 @@ fn node_registry() -> HashMap<NodeIdentifier, HashMap<NodeIOTypes, NodeConstruct
let node = ComposeTypeErased::new(args[0], args[1]);
node.into_type_erased()
},
NodeIOTypes::new(generic!(T), generic!(U), vec![generic!(V), generic!(U)]),
NodeIOTypes::new(generic!(T), generic!(U), vec![(generic!(T), generic!(V)), (generic!(V), generic!(U))]),
),
// Filters
raster_node!(graphene_core::raster::LuminanceNode<_>, params: [LuminanceCalculation]),
@ -146,25 +146,25 @@ fn node_registry() -> HashMap<NodeIdentifier, HashMap<NodeIOTypes, NodeConstruct
concrete!(Image),
concrete!(Image),
vec![
concrete!(DAffine2),
concrete!(f64),
concrete!(Option<DVec2>),
concrete!(f64),
concrete!(ImaginateSamplingMethod),
concrete!(f64),
concrete!(String),
concrete!(String),
concrete!(bool),
concrete!(f64),
concrete!(Option<Vec<u64>>),
concrete!(bool),
concrete!(f64),
concrete!(ImaginateMaskStartingFill),
concrete!(bool),
concrete!(bool),
concrete!(Option<std::sync::Arc<Image>>),
concrete!(f64),
concrete!(ImaginateStatus),
(concrete!(()), concrete!(DAffine2)),
(concrete!(()), concrete!(f64)),
(concrete!(()), concrete!(Option<DVec2>)),
(concrete!(()), concrete!(f64)),
(concrete!(()), concrete!(ImaginateSamplingMethod)),
(concrete!(()), concrete!(f64)),
(concrete!(()), concrete!(String)),
(concrete!(()), concrete!(String)),
(concrete!(()), concrete!(bool)),
(concrete!(()), concrete!(f64)),
(concrete!(()), concrete!(Option<Vec<u64>>)),
(concrete!(()), concrete!(bool)),
(concrete!(()), concrete!(f64)),
(concrete!(()), concrete!(ImaginateMaskStartingFill)),
(concrete!(()), concrete!(bool)),
(concrete!(()), concrete!(bool)),
(concrete!(()), concrete!(Option<std::sync::Arc<Image>>)),
(concrete!(()), concrete!(f64)),
(concrete!(()), concrete!(ImaginateStatus)),
],
),
),
@ -203,7 +203,7 @@ fn node_registry() -> HashMap<NodeIdentifier, HashMap<NodeIOTypes, NodeConstruct
let node: DynAnyNode<&Image, _, _> = DynAnyNode::new(ValueNode::new(new_image));
node.into_type_erased()
},
NodeIOTypes::new(concrete!(Image), concrete!(Image), vec![concrete!(u32), concrete!(f64)]),
NodeIOTypes::new(concrete!(Image), concrete!(Image), vec![(concrete!(()), concrete!(u32)), (concrete!(()), concrete!(f64))]),
),
//register_node!(graphene_std::memo::CacheNode<_>, input: Image, params: []),
(