Add graph type error diagnostics to the UI (#1535)

* Fontend input types

* Fix index of errors / types

* Bug fixes, styling improvements, and code review

* Improvements to the error box

---------

Co-authored-by: Keavon Chambers <keavon@keavon.com>
This commit is contained in:
0HyperCube 2023-12-29 08:38:45 +00:00 committed by GitHub
parent 96b5d7b520
commit 947a131a4b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
20 changed files with 566 additions and 170 deletions

View file

@ -6,7 +6,7 @@ use dyn_any::StaticType;
#[cfg(feature = "std")]
pub use std::borrow::Cow;
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct NodeIOTypes {
pub input: Type,
pub output: Type,
@ -23,6 +23,16 @@ impl NodeIOTypes {
}
}
impl core::fmt::Debug for NodeIOTypes {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.write_fmt(format_args!(
"node({}) -> {}",
[&self.input].into_iter().chain(&self.parameters).map(|input| input.to_string()).collect::<Vec<_>>().join(", "),
self.output
))
}
}
#[macro_export]
macro_rules! concrete {
($type:ty) => {
@ -193,6 +203,13 @@ impl Type {
}
}
fn format_type(ty: &str) -> String {
ty.split('<')
.map(|path| path.split(',').map(|path| path.split("::").last().unwrap_or(path)).collect::<Vec<_>>().join(","))
.collect::<Vec<_>>()
.join("<")
}
impl core::fmt::Debug for Type {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
match self {
@ -200,7 +217,7 @@ impl core::fmt::Debug for Type {
#[cfg(feature = "type_id_logging")]
Self::Concrete(arg0) => write!(f, "Concrete({}, {:?})", arg0.name, arg0.id),
#[cfg(not(feature = "type_id_logging"))]
Self::Concrete(arg0) => write!(f, "Concrete({})", arg0.name),
Self::Concrete(arg0) => write!(f, "Concrete({})", format_type(&arg0.name)),
Self::Fn(arg0, arg1) => write!(f, "({arg0:?} -> {arg1:?})"),
Self::Future(arg0) => write!(f, "Future({arg0:?})"),
}
@ -211,7 +228,7 @@ impl std::fmt::Display for Type {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Type::Generic(name) => write!(f, "{name}"),
Type::Concrete(ty) => write!(f, "{}", ty.name),
Type::Concrete(ty) => write!(f, "{}", format_type(&ty.name)),
Type::Fn(input, output) => write!(f, "({input} -> {output})"),
Type::Future(ty) => write!(f, "Future<{ty}>"),
}

View file

@ -166,9 +166,32 @@ pub struct DocumentNode {
/// Used as a hash of the graph input where applicable. This ensures that protonodes that depend on the graph's input are always regenerated.
#[serde(default)]
pub world_state_hash: u64,
/// The path to this node as of when [`NodeNetwork::generate_node_paths`] was called.
/// For example if this node was ID 6 inside a node with ID 4 and with a [`DocumentNodeImplementation::Network`], the path would be [4, 6].
/// The path to this node and its inputs and outputs as of when [`NodeNetwork::generate_node_paths`] was called.
#[serde(skip)]
pub original_location: OriginalLocation,
}
/// Represents the original location of a node input/output when [`NodeNetwork::generate_node_paths`] was called, allowing the types and errors to be derived.
#[derive(Clone, Debug, PartialEq, Eq, Hash, DynAny)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Source {
pub node: Vec<NodeId>,
pub index: usize,
}
/// The path to this node and its inputs and outputs as of when [`NodeNetwork::generate_node_paths`] was called.
#[derive(Clone, Debug, PartialEq, Eq, DynAny, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct OriginalLocation {
/// The original location to the document node - e.g. [grandparent_id, parent_id, node_id].
pub path: Option<Vec<NodeId>>,
/// Each document input source maps to one protonode input (however one protonode input may come from several sources)
pub inputs_source: HashMap<Source, usize>,
/// A list of document sources for the node's output
pub outputs_source: HashMap<Source, usize>,
pub inputs_exposed: Vec<bool>,
/// Skipping inputs is useful for the manual composition thing - whereby a hidden `Footprint` input is added as the first input.
pub skip_inputs: usize,
}
impl Default for DocumentNode {
@ -183,14 +206,42 @@ impl Default for DocumentNode {
metadata: Default::default(),
skip_deduplication: Default::default(),
world_state_hash: Default::default(),
path: Default::default(),
original_location: OriginalLocation::default(),
}
}
}
impl Hash for OriginalLocation {
fn hash<H: Hasher>(&self, state: &mut H) {
self.path.hash(state);
self.inputs_source.iter().for_each(|val| val.hash(state));
self.outputs_source.iter().for_each(|val| val.hash(state));
self.inputs_exposed.hash(state);
self.skip_inputs.hash(state);
}
}
impl OriginalLocation {
pub fn inputs<'a>(&'a self, index: usize) -> impl Iterator<Item = Source> + 'a {
[(index >= self.skip_inputs).then(|| Source {
node: self.path.clone().unwrap_or_default(),
index: self.inputs_exposed.iter().take(index - self.skip_inputs).filter(|&&exposed| exposed).count(),
})]
.into_iter()
.flatten()
.chain(self.inputs_source.iter().filter(move |x| *x.1 == index).map(|(source, _)| source.clone()))
}
pub fn outputs<'a>(&'a self, index: usize) -> impl Iterator<Item = Source> + 'a {
[Source {
node: self.path.clone().unwrap_or_default(),
index,
}]
.into_iter()
.chain(self.outputs_source.iter().filter(move |x| *x.1 == index).map(|(source, _)| source.clone()))
}
}
impl DocumentNode {
/// Locate the input that is a [`NodeInput::Network`] at index `offset` and replace it with a [`NodeInput::Node`].
pub fn populate_first_network_input(&mut self, node_id: NodeId, output_index: usize, offset: usize, lambda: bool) {
pub fn populate_first_network_input(&mut self, node_id: NodeId, output_index: usize, offset: usize, lambda: bool, source: impl Iterator<Item = Source>, skip: usize) {
let (index, _) = self
.inputs
.iter()
@ -200,6 +251,10 @@ impl DocumentNode {
.unwrap_or_else(|| panic!("no network input found for {self:#?} and offset: {offset}"));
self.inputs[index] = NodeInput::Node { node_id, output_index, lambda };
let input_source = &mut self.original_location.inputs_source;
for source in source {
input_source.insert(source, index + self.original_location.skip_inputs - skip);
}
}
fn resolve_proto_node(mut self) -> ProtoNode {
@ -246,7 +301,7 @@ impl DocumentNode {
identifier: fqn,
input,
construction_args: args,
document_node_path: self.path.unwrap_or_default(),
original_location: self.original_location,
skip_deduplication: self.skip_deduplication,
world_state_hash: self.world_state_hash,
}
@ -762,10 +817,15 @@ impl NodeNetwork {
if let DocumentNodeImplementation::Network(network) = &mut node.implementation {
network.generate_node_paths(new_path.as_slice());
}
if node.path.is_some() {
if node.original_location.path.is_some() {
log::warn!("Attempting to overwrite node path");
} else {
node.path = Some(new_path);
node.original_location = OriginalLocation {
path: Some(new_path),
inputs_exposed: node.inputs.iter().map(|input| input.is_exposed()).collect(),
skip_inputs: if node.manual_composition.is_some() { 1 } else { 0 },
..Default::default()
}
}
}
}
@ -831,7 +891,6 @@ impl NodeNetwork {
/// Remove all nodes that contain [`DocumentNodeImplementation::Network`] by moving the nested nodes into the parent network.
pub fn flatten_with_fns(&mut self, node: NodeId, map_ids: impl Fn(NodeId, NodeId) -> NodeId + Copy, gen_id: impl Fn() -> NodeId + Copy) {
self.resolve_extract_nodes();
let Some((id, mut node)) = self.nodes.remove_entry(&node) else {
warn!("The node which was supposed to be flattened does not exist in the network, id {node} network {self:#?}");
return;
@ -850,7 +909,7 @@ impl NodeNetwork {
}
// replace value inputs with value nodes
for input in &mut node.inputs {
for input in node.inputs.iter_mut() {
// Skip inputs that are already value nodes
if node.implementation == DocumentNodeImplementation::Unresolved("graphene_core::value::ClonedNode".into()) {
break;
@ -860,20 +919,17 @@ impl NodeNetwork {
if let NodeInput::Value { tagged_value, exposed } = previous_input {
let value_node_id = gen_id();
let merged_node_id = map_ids(id, value_node_id);
let path = if let Some(mut new_path) = node.path.clone() {
new_path.push(value_node_id);
Some(new_path)
} else {
None
};
let mut original_location = node.original_location.clone();
if let Some(path) = &mut original_location.path {
path.push(value_node_id);
}
self.nodes.insert(
merged_node_id,
DocumentNode {
name: "Value".into(),
inputs: vec![NodeInput::Value { tagged_value, exposed }],
implementation: DocumentNodeImplementation::Unresolved("graphene_core::value::ClonedNode".into()),
path,
original_location,
..Default::default()
},
);
@ -888,8 +944,6 @@ impl NodeNetwork {
}
if let DocumentNodeImplementation::Network(mut inner_network) = node.implementation {
// Resolve all extract nodes in the inner network
inner_network.resolve_extract_nodes();
// Connect all network inputs to either the parent network nodes, or newly created value nodes.
inner_network.map_ids(|inner_id| map_ids(id, inner_id));
let new_nodes = inner_network.nodes.keys().cloned().collect::<Vec<_>>();
@ -914,14 +968,15 @@ impl NodeNetwork {
"Document Nodes with a Network implementation should have the same number of inner network inputs as inputs declared on the Document Node"
);
// Match the document node input and the inputs of the inner network
for (document_input, network_input) in node.inputs.into_iter().zip(inner_network.inputs.iter()) {
for (input_index, (document_input, network_input)) in node.inputs.into_iter().zip(inner_network.inputs.iter()).enumerate() {
// Keep track of how many network inputs we have already connected for each node
let offset = network_offsets.entry(network_input).or_insert(0);
match document_input {
// If the input to self is a node, connect the corresponding output of the inner network to it
NodeInput::Node { node_id, output_index, lambda } => {
let network_input = self.nodes.get_mut(network_input).unwrap();
network_input.populate_first_network_input(node_id, output_index, *offset, lambda);
let skip = node.original_location.skip_inputs;
network_input.populate_first_network_input(node_id, output_index, *offset, lambda, node.original_location.inputs(input_index), skip);
}
NodeInput::Network(_) => {
*network_offsets.get_mut(network_input).unwrap() += 1;
@ -941,6 +996,13 @@ impl NodeNetwork {
self.replace_node_inputs(node_input(id, i, false), node_input(output.node_id, output.node_output_index, false));
self.replace_node_inputs(node_input(id, i, true), node_input(output.node_id, output.node_output_index, true));
if let Some(new_output_node) = self.nodes.get_mut(&output.node_id) {
for source in node.original_location.outputs(i) {
info!("{:?} {}", source, output.node_output_index);
new_output_node.original_location.outputs_source.insert(source, output.node_output_index);
}
}
self.replace_network_outputs(NodeOutput::new(id, i), output);
}
@ -960,9 +1022,15 @@ impl NodeNetwork {
if ident.name == "graphene_core::ops::IdentityNode" {
assert_eq!(node.inputs.len(), 1, "Id node has more than one input");
if let NodeInput::Node { node_id, output_index, .. } = node.inputs[0] {
if let Some(input_node) = self.nodes.get_mut(&node_id) {
for source in node.original_location.outputs(0) {
input_node.original_location.outputs_source.insert(source, output_index);
}
}
let input_node_id = node_id;
for output in self.nodes.values_mut() {
for input in &mut output.inputs {
for (index, input) in output.inputs.iter_mut().enumerate() {
if let NodeInput::Node {
node_id: output_node_id,
output_index: output_output_index,
@ -972,6 +1040,11 @@ impl NodeNetwork {
if *output_node_id == id {
*output_node_id = input_node_id;
*output_output_index = output_index;
let input_source = &mut output.original_location.inputs_source;
for source in node.original_location.inputs(index) {
input_source.insert(source, index + output.original_location.skip_inputs - node.original_location.skip_inputs);
}
}
}
}
@ -1300,7 +1373,14 @@ mod test {
identifier: "graphene_core::structural::ConsNode".into(),
input: ProtoNodeInput::ManualComposition(concrete!(u32)),
construction_args: ConstructionArgs::Nodes(vec![(NodeId(14), false)]),
document_node_path: vec![NodeId(1), NodeId(0)],
original_location: OriginalLocation {
path: Some(vec![NodeId(1), NodeId(0)]),
inputs_source: [(Source { node: vec![NodeId(1)], index: 0 }, 1)].into(),
outputs_source: HashMap::new(),
inputs_exposed: vec![false, false],
skip_inputs: 0,
},
..Default::default()
},
),
@ -1310,7 +1390,13 @@ mod test {
identifier: "graphene_core::ops::AddPairNode".into(),
input: ProtoNodeInput::Node(NodeId(10), false),
construction_args: ConstructionArgs::Nodes(vec![]),
document_node_path: vec![NodeId(1), NodeId(1)],
original_location: OriginalLocation {
path: Some(vec![NodeId(1), NodeId(1)]),
inputs_source: HashMap::new(),
outputs_source: [(Source { node: vec![NodeId(1)], index: 0 }, 0)].into(),
inputs_exposed: vec![true],
skip_inputs: 0,
},
..Default::default()
},
),
@ -1338,7 +1424,13 @@ mod test {
name: "Cons".into(),
inputs: vec![NodeInput::Network(concrete!(u32)), NodeInput::node(NodeId(14), 0)],
implementation: DocumentNodeImplementation::Unresolved("graphene_core::structural::ConsNode".into()),
path: Some(vec![NodeId(1), NodeId(0)]),
original_location: OriginalLocation {
path: Some(vec![NodeId(1), NodeId(0)]),
inputs_source: [(Source { node: vec![NodeId(1)], index: 0 }, 1)].into(),
outputs_source: HashMap::new(),
inputs_exposed: vec![false, false],
skip_inputs: 0,
},
..Default::default()
},
),
@ -1351,7 +1443,13 @@ mod test {
exposed: false,
}],
implementation: DocumentNodeImplementation::Unresolved("graphene_core::value::ClonedNode".into()),
path: Some(vec![NodeId(1), NodeId(4)]),
original_location: OriginalLocation {
path: Some(vec![NodeId(1), NodeId(4)]),
inputs_source: HashMap::new(),
outputs_source: HashMap::new(),
inputs_exposed: vec![false, false],
skip_inputs: 0,
},
..Default::default()
},
),
@ -1361,7 +1459,13 @@ mod test {
name: "Add".into(),
inputs: vec![NodeInput::node(NodeId(10), 0)],
implementation: DocumentNodeImplementation::Unresolved("graphene_core::ops::AddPairNode".into()),
path: Some(vec![NodeId(1), NodeId(1)]),
original_location: OriginalLocation {
path: Some(vec![NodeId(1), NodeId(1)]),
inputs_source: HashMap::new(),
outputs_source: [(Source { node: vec![NodeId(1)], index: 0 }, 0)].into(),
inputs_exposed: vec![true],
skip_inputs: 0,
},
..Default::default()
},
),

View file

@ -1,16 +1,16 @@
use std::borrow::Cow;
use std::collections::{HashMap, HashSet};
use std::ops::Deref;
use std::hash::Hash;
use crate::document::NodeId;
use crate::document::{value, InlineRust};
use crate::document::{NodeId, OriginalLocation};
use dyn_any::DynAny;
use graphene_core::*;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use std::borrow::Cow;
use std::collections::{HashMap, HashSet};
use std::fmt::Debug;
use std::hash::Hash;
use std::ops::Deref;
use std::pin::Pin;
pub type DynFuture<'n, T> = Pin<Box<dyn core::future::Future<Output = T> + 'n>>;
@ -205,7 +205,7 @@ pub struct ProtoNode {
pub construction_args: ConstructionArgs,
pub input: ProtoNodeInput,
pub identifier: ProtoNodeIdentifier,
pub document_node_path: Vec<NodeId>,
pub original_location: OriginalLocation,
pub skip_deduplication: bool,
// TODO: This is a hack, figure out a proper solution
/// Represents a global state on which the node depends.
@ -218,7 +218,7 @@ impl Default for ProtoNode {
identifier: ProtoNodeIdentifier::new("graphene_core::ops::IdentityNode"),
construction_args: ConstructionArgs::Value(value::TaggedValue::U32(0)),
input: ProtoNodeInput::None,
document_node_path: vec![],
original_location: OriginalLocation::default(),
skip_deduplication: false,
world_state_hash: 0,
}
@ -266,7 +266,7 @@ impl ProtoNode {
self.identifier.name.hash(&mut hasher);
self.construction_args.hash(&mut hasher);
if self.skip_deduplication {
self.document_node_path.hash(&mut hasher);
self.original_location.path.hash(&mut hasher);
}
self.world_state_hash.hash(&mut hasher);
std::mem::discriminant(&self.input).hash(&mut hasher);
@ -282,11 +282,19 @@ impl ProtoNode {
/// Construct a new [`ProtoNode`] with the specified construction args and a `ClonedNode` implementation.
pub fn value(value: ConstructionArgs, path: Vec<NodeId>) -> Self {
let inputs_exposed = match &value {
ConstructionArgs::Nodes(nodes) => nodes.len() + 1,
_ => 2,
};
Self {
identifier: ProtoNodeIdentifier::new("graphene_core::value::ClonedNode"),
construction_args: value,
input: ProtoNodeInput::None,
document_node_path: path,
original_location: OriginalLocation {
path: Some(path),
inputs_exposed: vec![false; inputs_exposed],
..Default::default()
},
skip_deduplication: false,
world_state_hash: 0,
}
@ -396,8 +404,10 @@ impl ProtoNetwork {
let input = input_node_id_proto.input.clone();
let mut path = input_node_id_proto.document_node_path.clone();
path.push(node_id);
let mut path = input_node_id_proto.original_location.path.clone();
if let Some(path) = &mut path {
path.push(node_id);
}
self.nodes.push((
compose_node_id,
@ -405,7 +415,7 @@ impl ProtoNetwork {
identifier: ProtoNodeIdentifier::new("graphene_core::structural::ComposeNode<_, _, _>"),
construction_args: ConstructionArgs::Nodes(vec![(input_node_id, false), (node_id, true)]),
input,
document_node_path: path,
original_location: OriginalLocation { path, ..Default::default() },
skip_deduplication: false,
world_state_hash: 0,
},
@ -544,6 +554,78 @@ impl ProtoNetwork {
Ok(())
}
}
#[derive(Clone, PartialEq)]
pub enum GraphErrorType {
NodeNotFound(NodeId),
InputNodeNotFound(NodeId),
UnexpectedGenerics { index: usize, parameters: Vec<Type> },
NoImplementations,
NoConstructor,
InvalidImplementations { parameters: String, error_inputs: Vec<Vec<(usize, (Type, Type))>> },
MultipleImplementations { parameters: String, valid: Vec<NodeIOTypes> },
}
impl core::fmt::Debug for GraphErrorType {
// TODO: format with the document graph context so the input index is the same as in the graph UI.
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
GraphErrorType::NodeNotFound(id) => write!(f, "Input node {id} is not present in the typing context"),
GraphErrorType::InputNodeNotFound(id) => write!(f, "Input node {id} is not present in the typing context"),
GraphErrorType::UnexpectedGenerics { index, parameters } => write!(f, "Generic parameters should not exist but found at {index}: {parameters:?}"),
GraphErrorType::NoImplementations => write!(f, "No implementations found"),
GraphErrorType::NoConstructor => write!(f, "No construct found for node"),
GraphErrorType::InvalidImplementations { parameters, error_inputs } => {
let ordinal = |x: usize| match x.to_string().as_str() {
x if x.ends_with('1') && !x.ends_with("11") => format!("{x}st"),
x if x.ends_with('2') && !x.ends_with("12") => format!("{x}nd"),
x if x.ends_with('3') && !x.ends_with("13") => format!("{x}rd"),
x => format!("{x}th parameter"),
};
let format_index = |index: usize| if index == 0 { "primary".to_string() } else { format!("{} parameter", ordinal(index - 1)) };
let format_error = |(index, (real, expected)): &(usize, (Type, Type))| format!("• The {} input expected {} but found {}", format_index(*index), expected, real);
let format_error_list = |errors: &Vec<(usize, (Type, Type))>| errors.iter().map(format_error).collect::<Vec<_>>().join("\n");
let errors = error_inputs.iter().map(format_error_list).collect::<Vec<_>>();
write!(
f,
"Node graph type error! If this just appeared while editing the graph,\n\
consider using undo to go back and trying another way to connect the nodes.\n\
\n\
No node implementation exists for type ({parameters}).\n\
\n\
Caused by{}:\n\
{}",
if errors.len() > 1 { " one of" } else { "" },
errors.join("\n")
)
}
GraphErrorType::MultipleImplementations { parameters, valid } => write!(f, "Multiple implementations found ({parameters}):\n{valid:#?}"),
}
}
}
#[derive(Clone, PartialEq)]
pub struct GraphError {
pub node_path: Vec<NodeId>,
pub identifier: Cow<'static, str>,
pub error: GraphErrorType,
}
impl GraphError {
pub fn new(node: &ProtoNode, text: impl Into<GraphErrorType>) -> Self {
Self {
node_path: node.original_location.path.clone().unwrap_or_default(),
identifier: node.identifier.name.clone(),
error: text.into(),
}
}
}
impl core::fmt::Debug for GraphError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("NodeGraphError")
.field("path", &self.node_path.iter().map(|id| id.0).collect::<Vec<_>>())
.field("identifier", &self.identifier.to_string())
.field("error", &self.error)
.finish()
}
}
pub type GraphErrors = Vec<GraphError>;
/// The `TypingContext` is used to store the types of the nodes indexed by their stable node id.
#[derive(Default, Clone)]
@ -562,10 +644,10 @@ impl TypingContext {
}
}
/// Updates the `TypingContext` wtih a given proto network. This will infer the types of the nodes
/// Updates the `TypingContext` with a given proto network. This will infer the types of the nodes
/// and store them in the `inferred` field. The proto network has to be topologically sorted
/// and contain fully resolved stable node ids.
pub fn update(&mut self, network: &ProtoNetwork) -> Result<(), String> {
pub fn update(&mut self, network: &ProtoNetwork) -> Result<(), GraphErrors> {
for (id, node) in network.nodes.iter() {
self.infer(*id, node)?;
}
@ -583,12 +665,10 @@ impl TypingContext {
}
/// Returns the inferred types for a given node id.
pub fn infer(&mut self, node_id: NodeId, node: &ProtoNode) -> Result<NodeIOTypes, String> {
let identifier = node.identifier.name.clone();
pub fn infer(&mut self, node_id: NodeId, node: &ProtoNode) -> Result<NodeIOTypes, GraphErrors> {
// Return the inferred type if it is already known
if let Some(infered) = self.inferred.get(&node_id) {
return Ok(infered.clone());
if let Some(inferred) = self.inferred.get(&node_id) {
return Ok(inferred.clone());
}
let parameters = match node.construction_args {
@ -606,10 +686,10 @@ impl TypingContext {
.map(|(id, _)| {
self.inferred
.get(id)
.ok_or(format!("Inferring type of {node_id} depends on {id} which is not present in the typing context"))
.ok_or_else(|| vec![GraphError::new(node, GraphErrorType::NodeNotFound(*id))])
.map(|node| node.ty())
})
.collect::<Result<Vec<Type>, String>>()?,
.collect::<Result<Vec<Type>, GraphErrors>>()?,
ConstructionArgs::Inline(ref inline) => vec![inline.ty.clone()],
};
@ -618,23 +698,17 @@ impl TypingContext {
ProtoNodeInput::None => concrete!(()),
ProtoNodeInput::ManualComposition(ref ty) => ty.clone(),
ProtoNodeInput::Node(id, _) => {
let input = self
.inferred
.get(&id)
.ok_or(format!("Inferring type of {node_id} depends on {id} which is not present in the typing context"))?;
let input = self.inferred.get(&id).ok_or_else(|| vec![GraphError::new(node, GraphErrorType::InputNodeNotFound(id))])?;
input.output.clone()
}
};
let impls = self
.lookup
.get(&node.identifier)
.ok_or(format!("No implementations found for:\n\n{:?}\n\nOther implementations found:\n\n{:?}", node.identifier, self.lookup))?;
let impls = self.lookup.get(&node.identifier).ok_or_else(|| vec![GraphError::new(node, GraphErrorType::NoImplementations)])?;
if parameters.iter().any(|p| {
if let Some(index) = parameters.iter().position(|p| {
matches!(p,
Type::Fn(_, b) if matches!(b.as_ref(), Type::Generic(_)))
}) {
return Err(format!("Generic types are not supported in parameters: {:?} occurred in {:?}", parameters, node.identifier));
return Err(vec![GraphError::new(node, GraphErrorType::UnexpectedGenerics { index, parameters })]);
}
fn covariant(from: &Type, to: &Type) -> bool {
match (from, to) {
@ -651,7 +725,7 @@ impl TypingContext {
// List of all implementations that match the input and parameter types
let valid_output_types = impls
.keys()
.filter(|node_io| covariant(&input, &node_io.input) && parameters.iter().zip(node_io.parameters.iter()).all(|(p1, p2)| covariant(p1, p2) && covariant(p1, p2)))
.filter(|node_io| covariant(&input, &node_io.input) && parameters.iter().zip(node_io.parameters.iter()).all(|(p1, p2)| covariant(p1, p2)))
.collect::<Vec<_>>();
// Attempt to substitute generic types with concrete types and save the list of results
@ -677,10 +751,28 @@ impl TypingContext {
match valid_impls.as_slice() {
[] => {
dbg!(&self.inferred);
Err(format!(
"No implementations found for:\n\n{identifier}\n\nwith input:\n\n{input:?}\n\nand parameters:\n\n{parameters:?}\n\nOther Implementations found:\n\n{:?}",
impls.keys().collect::<Vec<_>>(),
))
let mut best_errors = usize::MAX;
let mut error_inputs = Vec::new();
for node_io in impls.keys() {
let current_errors = [&input]
.into_iter()
.chain(&parameters)
.cloned()
.zip([&node_io.input].into_iter().chain(&node_io.parameters).cloned())
.enumerate()
.filter(|(_, (p1, p2))| !covariant(p1, p2))
.map(|(index, ty)| (node.original_location.inputs(index).min_by_key(|s| s.node.len()).map(|s| s.index).unwrap_or(index), ty))
.collect::<Vec<_>>();
if current_errors.len() < best_errors {
best_errors = current_errors.len();
error_inputs.clear();
}
if current_errors.len() <= best_errors {
error_inputs.push(current_errors);
}
}
let parameters = [&input].into_iter().chain(&parameters).map(|t| t.to_string()).collect::<Vec<_>>().join(", ");
Err(vec![GraphError::new(node, GraphErrorType::InvalidImplementations { parameters, error_inputs })])
}
[(org_nio, output)] => {
let node_io = NodeIOTypes::new(input, (*output).clone(), parameters);
@ -690,9 +782,12 @@ impl TypingContext {
self.constructor.insert(node_id, impls[org_nio]);
Ok(node_io)
}
_ => Err(format!(
"Multiple implementations found for {identifier} with input {input:?} and parameters {parameters:?} (valid types: {valid_output_types:?}"
)),
_ => {
let parameters = [&input].into_iter().chain(&parameters).map(|t| t.to_string()).collect::<Vec<_>>().join(", ");
let valid = valid_output_types.into_iter().cloned().collect();
Err(vec![GraphError::new(node, GraphErrorType::MultipleImplementations { parameters, valid })])
}
}
}
}

View file

@ -1,15 +1,16 @@
use std::collections::{HashMap, HashSet};
use std::error::Error;
use std::sync::Arc;
use crate::node_registry;
use dyn_any::StaticType;
use graph_craft::document::value::{TaggedValue, UpcastNode};
use graph_craft::document::NodeId;
use graph_craft::document::{NodeId, Source};
use graph_craft::graphene_compiler::Executor;
use graph_craft::proto::{ConstructionArgs, LocalFuture, NodeContainer, ProtoNetwork, ProtoNode, SharedNodeContainer, TypeErasedBox, TypingContext};
use graph_craft::proto::{ConstructionArgs, GraphError, LocalFuture, NodeContainer, ProtoNetwork, ProtoNode, SharedNodeContainer, TypeErasedBox, TypingContext};
use graph_craft::proto::{GraphErrorType, GraphErrors};
use graph_craft::Type;
use crate::node_registry;
use std::collections::{HashMap, HashSet};
use std::error::Error;
use std::sync::Arc;
/// An executor of a node graph that does not require an online compilation server, and instead uses `Box<dyn ...>`.
pub struct DynamicExecutor {
@ -33,8 +34,15 @@ impl Default for DynamicExecutor {
}
}
#[derive(PartialEq, Clone, Debug, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ResolvedDocumentNodeTypes {
pub inputs: HashMap<Source, Type>,
pub outputs: HashMap<Source, Type>,
}
impl DynamicExecutor {
pub async fn new(proto_network: ProtoNetwork) -> Result<Self, String> {
pub async fn new(proto_network: ProtoNetwork) -> Result<Self, GraphErrors> {
let mut typing_context = TypingContext::new(&node_registry::NODE_REGISTRY);
typing_context.update(&proto_network)?;
let output = proto_network.output;
@ -49,7 +57,7 @@ impl DynamicExecutor {
}
/// Updates the existing [`BorrowTree`] to reflect the new [`ProtoNetwork`], reusing nodes where possible.
pub async fn update(&mut self, proto_network: ProtoNetwork) -> Result<(), String> {
pub async fn update(&mut self, proto_network: ProtoNetwork) -> Result<(), GraphErrors> {
self.output = proto_network.output;
self.typing_context.update(&proto_network)?;
let mut orphans = self.tree.update(proto_network, &self.typing_context).await?;
@ -74,6 +82,22 @@ impl DynamicExecutor {
pub fn output_type(&self) -> Option<Type> {
self.typing_context.type_of(self.output).map(|node_io| node_io.output.clone())
}
pub fn document_node_types(&self) -> ResolvedDocumentNodeTypes {
let mut resolved_document_node_types = ResolvedDocumentNodeTypes::default();
for (source, &(protonode_id, protonode_index)) in self.tree.inputs_source_map() {
let Some(node_io) = self.typing_context.type_of(protonode_id) else { continue };
let Some(ty) = [&node_io.input].into_iter().chain(&node_io.parameters).nth(protonode_index) else {
continue;
};
resolved_document_node_types.inputs.insert(source.clone(), ty.clone());
}
for (source, &protonode_id) in self.tree.outputs_source_map() {
let Some(node_io) = self.typing_context.type_of(protonode_id) else { continue };
resolved_document_node_types.outputs.insert(source.clone(), node_io.output.clone());
}
resolved_document_node_types
}
}
impl<'a, I: StaticType + 'a> Executor<I, TaggedValue> for &'a DynamicExecutor {
@ -89,10 +113,14 @@ pub struct BorrowTree {
nodes: HashMap<NodeId, SharedNodeContainer>,
/// A hashmap from the document path to the protonode ID.
source_map: HashMap<Vec<NodeId>, NodeId>,
/// Each document input source maps to one protonode input (however one protonode input may come from several sources)
inputs_source_map: HashMap<Source, (NodeId, usize)>,
/// A mapping of document input sources to the (single) protonode output
outputs_source_map: HashMap<Source, NodeId>,
}
impl BorrowTree {
pub async fn new(proto_network: ProtoNetwork, typing_context: &TypingContext) -> Result<BorrowTree, String> {
pub async fn new(proto_network: ProtoNetwork, typing_context: &TypingContext) -> Result<BorrowTree, GraphErrors> {
let mut nodes = BorrowTree::default();
for (id, node) in proto_network.nodes {
nodes.push_node(id, node, typing_context).await?
@ -101,7 +129,7 @@ impl BorrowTree {
}
/// Pushes new nodes into the tree and return orphaned nodes
pub async fn update(&mut self, proto_network: ProtoNetwork, typing_context: &TypingContext) -> Result<Vec<NodeId>, String> {
pub async fn update(&mut self, proto_network: ProtoNetwork, typing_context: &TypingContext) -> Result<Vec<NodeId>, GraphErrors> {
let mut old_nodes: HashSet<_> = self.nodes.keys().copied().collect();
for (id, node) in proto_network.nodes {
if !self.nodes.contains_key(&id) {
@ -110,6 +138,8 @@ impl BorrowTree {
old_nodes.remove(&id);
}
self.source_map.retain(|_, nid| !old_nodes.contains(nid));
self.inputs_source_map.retain(|_, (nid, _)| !old_nodes.contains(nid));
self.outputs_source_map.retain(|_, nid| !old_nodes.contains(nid));
self.nodes.retain(|nid, _| !old_nodes.contains(nid));
Ok(old_nodes.into_iter().collect())
}
@ -152,18 +182,23 @@ impl BorrowTree {
}
/// Insert a new node into the borrow tree, calling the constructor function from `node_registry.rs`.
pub async fn push_node(&mut self, id: NodeId, proto_node: ProtoNode, typing_context: &TypingContext) -> Result<(), String> {
let ProtoNode {
construction_args,
identifier,
document_node_path,
..
} = proto_node;
self.source_map.insert(document_node_path, id);
pub async fn push_node(&mut self, id: NodeId, proto_node: ProtoNode, typing_context: &TypingContext) -> Result<(), GraphErrors> {
self.source_map.insert(proto_node.original_location.path.clone().unwrap_or_default(), id);
match construction_args {
let params = match &proto_node.construction_args {
ConstructionArgs::Nodes(nodes) => nodes.len() + 1,
_ => 2,
};
self.inputs_source_map
.extend((0..params).flat_map(|i| proto_node.original_location.inputs(i).map(move |source| (source, (id, i)))));
self.outputs_source_map.extend(proto_node.original_location.outputs(0).map(|source| (source, id)));
for x in proto_node.original_location.outputs_source.values() {
assert_eq!(*x, 0, "protonodes should refer to output index 0");
}
match &proto_node.construction_args {
ConstructionArgs::Value(value) => {
let upcasted = UpcastNode::new(value);
let upcasted = UpcastNode::new(value.to_owned());
let node = Box::new(upcasted) as TypeErasedBox<'_>;
let node = NodeContainer::new(node);
self.store_node(node, id);
@ -172,7 +207,7 @@ impl BorrowTree {
ConstructionArgs::Nodes(ids) => {
let ids: Vec<_> = ids.iter().map(|(id, _)| *id).collect();
let construction_nodes = self.node_deps(&ids);
let constructor = typing_context.constructor(id).ok_or(format!("No constructor found for node {identifier:?}"))?;
let constructor = typing_context.constructor(id).ok_or_else(|| vec![GraphError::new(&proto_node, GraphErrorType::NoConstructor)])?;
let node = constructor(construction_nodes).await;
let node = NodeContainer::new(node);
self.store_node(node, id);
@ -180,6 +215,14 @@ impl BorrowTree {
};
Ok(())
}
pub fn inputs_source_map(&self) -> impl Iterator<Item = (&Source, &(NodeId, usize))> {
self.inputs_source_map.iter()
}
pub fn outputs_source_map(&self) -> impl Iterator<Item = (&Source, &NodeId)> {
self.outputs_source_map.iter()
}
}
#[cfg(test)]

View file

@ -73,7 +73,7 @@ mod tests {
let compiler = Compiler {};
let protograph = compiler.compile_single(network).expect("Graph should be generated");
let exec = block_on(DynamicExecutor::new(protograph)).unwrap_or_else(|e| panic!("Failed to create executor: {e}"));
let exec = block_on(DynamicExecutor::new(protograph)).unwrap_or_else(|e| panic!("Failed to create executor: {e:?}"));
let result = block_on((&exec).execute(32_u32)).unwrap();
assert_eq!(result, TaggedValue::U32(33));