Add type checking to the node graph (#1025)

* Implement type inference

Add type hints to node trait

Add type annotation infrastructure

Refactor type ascription infrastructure

Run cargo fix

Insert infer types stub

Remove types from node identifier

* Implement covariance

* Disable rejection of generic inputs + parameters

* Fix lints

* Extend type checking to cover Network inputs

* Implement generic specialization

* Relax covariance rules

* Fix type annotations for TypErasedComposeNode

* Fix type checking errors

* Keep connection information during node resolution
* Fix TypeDescriptor PartialEq implementation

* Apply review suggestions

* Add documentation to type inference

* Add Imaginate node to document node types

* Fix whitespace in macros

* Add types to imaginate node

* Fix type declaration for imaginate node + add console logging

* Use fully qualified type names as fallback during comparison

---------

Co-authored-by: Keavon Chambers <keavon@keavon.com>
This commit is contained in:
Dennis Kobert 2023-02-15 23:31:30 +01:00 committed by Keavon Chambers
parent a64c856ec4
commit 5dab7de68d
25 changed files with 1365 additions and 1008 deletions

View file

@ -1,13 +1,15 @@
use crate::document::value::TaggedValue;
use crate::generic;
use crate::proto::{ConstructionArgs, NodeIdentifier, ProtoNetwork, ProtoNode, ProtoNodeInput, Type};
use crate::proto::{ConstructionArgs, ProtoNetwork, ProtoNode, ProtoNodeInput};
use graphene_core::{NodeIdentifier, Type};
use dyn_any::{DynAny, StaticType};
use glam::IVec2;
use graphene_core::TypeDescriptor;
use rand_chacha::{
rand_core::{RngCore, SeedableRng},
ChaCha20Rng,
};
use std::borrow::Cow;
use std::collections::{HashMap, HashSet};
use std::sync::Mutex;
@ -59,7 +61,7 @@ impl DocumentNode {
.inputs
.iter()
.enumerate()
.filter(|(_, input)| matches!(input, NodeInput::Network))
.filter(|(_, input)| matches!(input, NodeInput::Network(_)))
.nth(offset)
.expect("no network input");
@ -80,9 +82,9 @@ impl DocumentNode {
assert_eq!(output_index, 0, "Outputs should be flattened before converting to protonode.");
(ProtoNodeInput::Node(node_id), ConstructionArgs::Nodes(vec![]))
}
NodeInput::Network => (ProtoNodeInput::Network, ConstructionArgs::Nodes(vec![])),
NodeInput::Network(ty) => (ProtoNodeInput::Network(ty), ConstructionArgs::Nodes(vec![])),
};
assert!(!self.inputs.iter().any(|input| matches!(input, NodeInput::Network)), "recieved non resolved parameter");
assert!(!self.inputs.iter().any(|input| matches!(input, NodeInput::Network(_))), "recieved non resolved parameter");
assert!(
!self.inputs.iter().any(|input| matches!(input, NodeInput::Value { .. })),
"recieved value as parameter. inupts: {:#?}, construction_args: {:#?}",
@ -129,12 +131,12 @@ impl DocumentNode {
}
}
#[derive(Clone, Debug, specta::Type)]
#[derive(Debug, Clone, PartialEq, Hash, specta::Type)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum NodeInput {
Node { node_id: NodeId, output_index: usize },
Value { tagged_value: value::TaggedValue, exposed: bool },
Network,
Network(Type),
}
impl NodeInput {
@ -153,17 +155,14 @@ impl NodeInput {
match self {
NodeInput::Node { .. } => true,
NodeInput::Value { exposed, .. } => *exposed,
NodeInput::Network => false,
NodeInput::Network(_) => false,
}
}
}
impl PartialEq for NodeInput {
fn eq(&self, other: &Self) -> bool {
match (&self, &other) {
(Self::Node { node_id: n0, output_index: o0 }, Self::Node { node_id: n1, output_index: o1 }) => n0 == n1 && o0 == o1,
(Self::Value { tagged_value: v1, .. }, Self::Value { tagged_value: v2, .. }) => v1 == v2,
_ => core::mem::discriminant(self) == core::mem::discriminant(other),
pub fn ty(&self) -> Type {
match self {
NodeInput::Node { .. } => unreachable!("ty() called on NodeInput::Node"),
NodeInput::Value { tagged_value, .. } => tagged_value.ty(),
NodeInput::Network(ty) => ty.clone(),
}
}
}
@ -357,7 +356,7 @@ impl NodeNetwork {
.unwrap_or_else(|| panic!("The node which was supposed to be flattened does not exist in the network, id {} network {:#?}", node, self));
if self.disabled.contains(&id) {
node.implementation = DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_core::ops::IdNode", &[generic!("T")]));
node.implementation = DocumentNodeImplementation::Unresolved("graphene_core::ops::IdNode".into());
node.inputs.drain(1..);
self.nodes.insert(id, node);
return;
@ -394,7 +393,7 @@ impl NodeNetwork {
let value_node = DocumentNode {
name,
inputs: vec![NodeInput::Value { tagged_value, exposed }],
implementation: DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_core::value::ValueNode", &[generic!("T")])),
implementation: DocumentNodeImplementation::Unresolved("graphene_core::value::ValueNode".into()),
metadata: DocumentNodeMetadata::default(),
};
assert!(!self.nodes.contains_key(&new_id));
@ -402,7 +401,7 @@ impl NodeNetwork {
let network_input = self.nodes.get_mut(network_input).unwrap();
network_input.populate_first_network_input(new_id, 0, *offset);
}
NodeInput::Network => {
NodeInput::Network(_) => {
*network_offsets.get_mut(network_input).unwrap() += 1;
if let Some(index) = self.inputs.iter().position(|i| *i == id) {
self.inputs[index] = *network_input;
@ -410,7 +409,7 @@ impl NodeNetwork {
}
}
}
node.implementation = DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_core::ops::IdNode", &[generic!("T")]));
node.implementation = DocumentNodeImplementation::Unresolved("graphene_core::ops::IdNode".into());
node.inputs = inner_network
.outputs
.iter()
@ -419,6 +418,7 @@ impl NodeNetwork {
output_index: node_output_index,
})
.collect();
for node_id in new_nodes {
self.flatten_with_fns(node_id, map_ids, gen_id);
}
@ -456,8 +456,8 @@ impl NodeNetwork {
0,
DocumentNode {
name: "Input".into(),
inputs: vec![NodeInput::Network],
implementation: DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_core::ops::IdNode", &[generic!("T")])),
inputs: vec![NodeInput::Network(concrete!(u32))],
implementation: DocumentNodeImplementation::Unresolved("graphene_core::ops::IdNode".into()),
metadata: DocumentNodeMetadata { position: (8, 4).into() },
},
),
@ -466,7 +466,7 @@ impl NodeNetwork {
DocumentNode {
name: "Output".into(),
inputs: vec![NodeInput::node(output_node_id, 0)],
implementation: DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_core::ops::IdNode", &[generic!("T")])),
implementation: DocumentNodeImplementation::Unresolved("graphene_core::ops::IdNode".into()),
metadata: DocumentNodeMetadata { position: (output_offset, 4).into() },
},
),
@ -553,7 +553,8 @@ impl NodeNetwork {
#[cfg(test)]
mod test {
use super::*;
use crate::proto::{ConstructionArgs, NodeIdentifier, ProtoNetwork, ProtoNode, ProtoNodeInput};
use crate::proto::{ConstructionArgs, ProtoNetwork, ProtoNode, ProtoNodeInput};
use graphene_core::NodeIdentifier;
fn gen_node_id() -> NodeId {
static mut NODE_ID: NodeId = 3;
@ -572,9 +573,9 @@ mod test {
0,
DocumentNode {
name: "Cons".into(),
inputs: vec![NodeInput::Network, NodeInput::Network],
inputs: vec![NodeInput::Network(concrete!(u32)), NodeInput::Network(concrete!(u32))],
metadata: DocumentNodeMetadata::default(),
implementation: DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_core::structural::ConsNode", &[generic!("T"), generic!("U")])),
implementation: DocumentNodeImplementation::Unresolved("graphene_core::structural::ConsNode".into()),
},
),
(
@ -583,7 +584,7 @@ mod test {
name: "Add".into(),
inputs: vec![NodeInput::node(0, 0)],
metadata: DocumentNodeMetadata::default(),
implementation: DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_core::ops::AddNode", &[generic!("T"), generic!("U")])),
implementation: DocumentNodeImplementation::Unresolved("graphene_core::ops::AddNode".into()),
},
),
]
@ -605,9 +606,9 @@ mod test {
1,
DocumentNode {
name: "Cons".into(),
inputs: vec![NodeInput::Network, NodeInput::Network],
inputs: vec![NodeInput::Network(concrete!(u32)), NodeInput::Network(concrete!(u32))],
metadata: DocumentNodeMetadata::default(),
implementation: DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_core::structural::ConsNode", &[generic!("T"), generic!("U")])),
implementation: DocumentNodeImplementation::Unresolved("graphene_core::structural::ConsNode".into()),
},
),
(
@ -616,7 +617,7 @@ mod test {
name: "Add".into(),
inputs: vec![NodeInput::node(1, 0)],
metadata: DocumentNodeMetadata::default(),
implementation: DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_core::ops::AddNode", &[generic!("T"), generic!("U")])),
implementation: DocumentNodeImplementation::Unresolved("graphene_core::ops::AddNode".into()),
},
),
]
@ -637,7 +638,7 @@ mod test {
DocumentNode {
name: "Inc".into(),
inputs: vec![
NodeInput::Network,
NodeInput::Network(concrete!(u32)),
NodeInput::Value {
tagged_value: value::TaggedValue::U32(2),
exposed: false,
@ -663,15 +664,15 @@ mod test {
fn resolve_proto_node_add() {
let document_node = DocumentNode {
name: "Cons".into(),
inputs: vec![NodeInput::Network, NodeInput::node(0, 0)],
inputs: vec![NodeInput::Network(concrete!(u32)), NodeInput::node(0, 0)],
metadata: DocumentNodeMetadata::default(),
implementation: DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_core::structural::ConsNode", &[generic!("T"), generic!("U")])),
implementation: DocumentNodeImplementation::Unresolved("graphene_core::structural::ConsNode".into()),
};
let proto_node = document_node.resolve_proto_node();
let reference = ProtoNode {
identifier: NodeIdentifier::new("graphene_core::structural::ConsNode", &[generic!("T"), generic!("U")]),
input: ProtoNodeInput::Network,
identifier: "graphene_core::structural::ConsNode".into(),
input: ProtoNodeInput::Network(concrete!(u32)),
construction_args: ConstructionArgs::Nodes(vec![0]),
};
assert_eq!(proto_node, reference);
@ -686,7 +687,7 @@ mod test {
(
1,
ProtoNode {
identifier: NodeIdentifier::new("graphene_core::ops::IdNode", &[generic!("T")]),
identifier: "graphene_core::ops::IdNode".into(),
input: ProtoNodeInput::Node(11),
construction_args: ConstructionArgs::Nodes(vec![]),
},
@ -694,15 +695,15 @@ mod test {
(
10,
ProtoNode {
identifier: NodeIdentifier::new("graphene_core::structural::ConsNode", &[generic!("T"), generic!("U")]),
input: ProtoNodeInput::Network,
identifier: "graphene_core::structural::ConsNode".into(),
input: ProtoNodeInput::Network(concrete!(u32)),
construction_args: ConstructionArgs::Nodes(vec![14]),
},
),
(
11,
ProtoNode {
identifier: NodeIdentifier::new("graphene_core::ops::AddNode", &[generic!("T"), generic!("U")]),
identifier: "graphene_core::ops::AddNode".into(),
input: ProtoNodeInput::Node(10),
construction_args: ConstructionArgs::Nodes(vec![]),
},
@ -731,16 +732,16 @@ mod test {
name: "Inc".into(),
inputs: vec![NodeInput::node(11, 0)],
metadata: DocumentNodeMetadata::default(),
implementation: DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_core::ops::IdNode", &[generic!("T")])),
implementation: DocumentNodeImplementation::Unresolved("graphene_core::ops::IdNode".into()),
},
),
(
10,
DocumentNode {
name: "Cons".into(),
inputs: vec![NodeInput::Network, NodeInput::node(14, 0)],
inputs: vec![NodeInput::Network(concrete!(u32)), NodeInput::node(14, 0)],
metadata: DocumentNodeMetadata::default(),
implementation: DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_core::structural::ConsNode", &[generic!("T"), generic!("U")])),
implementation: DocumentNodeImplementation::Unresolved("graphene_core::structural::ConsNode".into()),
},
),
(
@ -752,7 +753,7 @@ mod test {
exposed: false,
}],
metadata: DocumentNodeMetadata::default(),
implementation: DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_core::value::ValueNode", &[generic!("T")])),
implementation: DocumentNodeImplementation::Unresolved("graphene_core::value::ValueNode".into()),
},
),
(
@ -761,7 +762,7 @@ mod test {
name: "Add".into(),
inputs: vec![NodeInput::node(10, 0)],
metadata: DocumentNodeMetadata::default(),
implementation: DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_core::ops::AddNode", &[generic!("T"), generic!("U")])),
implementation: DocumentNodeImplementation::Unresolved("graphene_core::ops::AddNode".into()),
},
),
]
@ -780,18 +781,18 @@ mod test {
1,
DocumentNode {
name: "Identity 1".into(),
inputs: vec![NodeInput::Network],
inputs: vec![NodeInput::Network(concrete!(u32))],
metadata: DocumentNodeMetadata::default(),
implementation: DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_core::ops::IdNode", &[generic!("T")])),
implementation: DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_core::ops::IdNode")),
},
),
(
2,
DocumentNode {
name: "Identity 2".into(),
inputs: vec![NodeInput::Network],
inputs: vec![NodeInput::Network(concrete!(u32))],
metadata: DocumentNodeMetadata::default(),
implementation: DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_core::ops::IdNode", &[generic!("T")])),
implementation: DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_core::ops::IdNode")),
},
),
]
@ -821,7 +822,7 @@ mod test {
name: "Result".into(),
inputs: vec![result_node_input],
metadata: DocumentNodeMetadata::default(),
implementation: DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_core::ops::IdNode", &[generic!("T")])),
implementation: DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_core::ops::IdNode")),
},
),
]