Fix crash when a cycle is introduced into the graph (#1427)

* Changing return of topological_sort to Result and propagating error

* Simplifying "compile()" method, adding "expect()" to tests.

* Removing Result type from "map_gpu()"

* Reverting to assertion and removing unnecessary returns
This commit is contained in:
Vlad Rakhmanin 2023-09-30 11:07:29 +01:00 committed by GitHub
parent 7e3469fa3f
commit b2397b06c6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 94 additions and 40 deletions

View file

@ -8,7 +8,7 @@ use crate::proto::{LocalFuture, ProtoNetwork};
pub struct Compiler {}
impl Compiler {
pub fn compile(&self, mut network: NodeNetwork) -> impl Iterator<Item = ProtoNetwork> {
pub fn compile(&self, mut network: NodeNetwork) -> Result<impl Iterator<Item = ProtoNetwork>, String> {
println!("flattening");
let node_ids = network.nodes.keys().copied().collect::<Vec<_>>();
for id in node_ids {
@ -17,15 +17,20 @@ impl Compiler {
network.remove_redundant_id_nodes();
network.remove_dead_nodes();
let proto_networks = network.into_proto_networks();
proto_networks.map(move |mut proto_network| {
proto_network.resolve_inputs();
proto_network.generate_stable_node_ids();
proto_network
})
let proto_networks_result: Vec<ProtoNetwork> = proto_networks
.map(move |mut proto_network| {
proto_network.resolve_inputs()?;
proto_network.generate_stable_node_ids();
Ok(proto_network)
})
.collect::<Result<Vec<ProtoNetwork>, String>>()?;
Ok(proto_networks_result.into_iter())
}
pub fn compile_single(&self, network: NodeNetwork) -> Result<ProtoNetwork, String> {
assert_eq!(network.outputs.len(), 1, "Graph with multiple outputs not yet handled");
let Some(proto_network) = self.compile(network).next() else {
let Some(proto_network) = self.compile(network)?.next() else {
return Err("Failed to convert graph into proto graph".to_string());
};
Ok(proto_network)

View file

@ -336,9 +336,9 @@ impl ProtoNetwork {
edges
}
pub fn resolve_inputs(&mut self) {
pub fn resolve_inputs(&mut self) -> Result<(), String> {
// Perform topological sort once
self.reorder_ids();
self.reorder_ids()?;
let max_id = self.nodes.len() as NodeId - 1;
@ -370,7 +370,8 @@ impl ProtoNetwork {
self.replace_node_id(&outwards_edges, node_id, compose_node_id, true);
}
}
self.reorder_ids();
self.reorder_ids()?;
Ok(())
}
fn replace_node_id(&mut self, outwards_edges: &HashMap<u64, Vec<u64>>, node_id: u64, compose_node_id: u64, skip_lambdas: bool) {
@ -392,33 +393,35 @@ impl ProtoNetwork {
}
});
}
// Based on https://en.wikipedia.org/wiki/Topological_sorting#Depth-first_search
// This approach excludes nodes that are not connected
pub fn topological_sort(&self) -> Vec<NodeId> {
pub fn topological_sort(&self) -> Result<Vec<NodeId>, String> {
let mut sorted = Vec::new();
let inwards_edges = self.collect_inwards_edges();
fn visit(node_id: NodeId, temp_marks: &mut HashSet<NodeId>, sorted: &mut Vec<NodeId>, inwards_edges: &HashMap<NodeId, Vec<NodeId>>, network: &ProtoNetwork) {
fn visit(node_id: NodeId, temp_marks: &mut HashSet<NodeId>, sorted: &mut Vec<NodeId>, inwards_edges: &HashMap<NodeId, Vec<NodeId>>, network: &ProtoNetwork) -> Result<(), String> {
if sorted.contains(&node_id) {
return;
return Ok(());
};
if temp_marks.contains(&node_id) {
panic!("Cycle detected {:#?}, {:#?}", &inwards_edges, &network);
return Err(format!("Cycle detected {:#?}, {:#?}", &inwards_edges, &network));
}
if let Some(dependencies) = inwards_edges.get(&node_id) {
temp_marks.insert(node_id);
for &dependant in dependencies {
visit(dependant, temp_marks, sorted, inwards_edges, network);
visit(dependant, temp_marks, sorted, inwards_edges, network)?;
}
temp_marks.remove(&node_id);
}
sorted.push(node_id);
Ok(())
}
assert!(self.nodes.iter().any(|(id, _)| *id == self.output), "Output id {} does not exist", self.output);
visit(self.output, &mut HashSet::new(), &mut sorted, &inwards_edges, self);
sorted
if !self.nodes.iter().any(|(id, _)| *id == self.output) {
return Err(format!("Output id {} does not exist", self.output));
}
visit(self.output, &mut HashSet::new(), &mut sorted, &inwards_edges, self)?;
Ok(sorted)
}
fn is_topologically_sorted(&self) -> bool {
@ -465,8 +468,8 @@ impl ProtoNetwork {
sorted
}*/
fn reorder_ids(&mut self) {
let order = self.topological_sort();
fn reorder_ids(&mut self) -> Result<(), String> {
let order = self.topological_sort()?;
// Map of node ids to their current index in the nodes vector
let current_positions: HashMap<_, _> = self.nodes.iter().enumerate().map(|(pos, (id, _))| (*id, pos)).collect();
@ -492,6 +495,7 @@ impl ProtoNetwork {
self.output = *new_positions.get(&self.output).unwrap();
assert_eq!(order.len(), self.nodes.len());
Ok(())
}
}
@ -687,17 +691,24 @@ mod test {
#[test]
fn topological_sort() {
let construction_network = test_network();
let sorted = construction_network.topological_sort();
let sorted = construction_network.topological_sort().expect("Error when calling 'topological_sort' on 'construction_network.");
println!("{:#?}", sorted);
assert_eq!(sorted, vec![14, 10, 11, 1]);
}
#[test]
fn topological_sort_with_cycles() {
let construction_network = test_network_with_cycles();
let sorted = construction_network.topological_sort();
assert!(sorted.is_err())
}
#[test]
fn id_reordering() {
let mut construction_network = test_network();
construction_network.reorder_ids();
let sorted = construction_network.topological_sort();
construction_network.reorder_ids().expect("Error when calling 'reorder_ids' on 'construction_network.");
let sorted = construction_network.topological_sort().expect("Error when calling 'topological_sort' on 'construction_network.");
println!("nodes: {:#?}", construction_network.nodes);
assert_eq!(sorted, vec![0, 1, 2, 3]);
let ids: Vec<_> = construction_network.nodes.iter().map(|(id, _)| *id).collect();
@ -710,9 +721,9 @@ mod test {
#[test]
fn id_reordering_idempotent() {
let mut construction_network = test_network();
construction_network.reorder_ids();
construction_network.reorder_ids();
let sorted = construction_network.topological_sort();
construction_network.reorder_ids().expect("Error when calling 'reorder_ids' on 'construction_network.");
construction_network.reorder_ids().expect("Error when calling 'reorder_ids' on 'construction_network.");
let sorted = construction_network.topological_sort().expect("Error when calling 'topological_sort' on 'construction_network.");
assert_eq!(sorted, vec![0, 1, 2, 3]);
let ids: Vec<_> = construction_network.nodes.iter().map(|(id, _)| *id).collect();
println!("{:#?}", ids);
@ -723,7 +734,7 @@ mod test {
#[test]
fn input_resolution() {
let mut construction_network = test_network();
construction_network.resolve_inputs();
construction_network.resolve_inputs().expect("Error when calling 'resolve_inputs' on 'construction_network.");
println!("{:#?}", construction_network);
assert_eq!(construction_network.nodes[0].1.identifier.name.as_ref(), "value");
assert_eq!(construction_network.nodes.len(), 6);
@ -733,7 +744,7 @@ mod test {
#[test]
fn stable_node_id_generation() {
let mut construction_network = test_network();
construction_network.resolve_inputs();
construction_network.resolve_inputs().expect("Error when calling 'resolve_inputs' on 'construction_network.");
construction_network.generate_stable_node_ids();
assert_eq!(construction_network.nodes[0].1.identifier.name.as_ref(), "value");
let ids: Vec<_> = construction_network.nodes.iter().map(|(id, _)| *id).collect();
@ -810,4 +821,35 @@ mod test {
.collect(),
}
}
fn test_network_with_cycles() -> ProtoNetwork {
ProtoNetwork {
inputs: vec![1],
output: 1,
nodes: [
(
1,
ProtoNode {
identifier: "id".into(),
input: ProtoNodeInput::Node(2, false),
construction_args: ConstructionArgs::Nodes(vec![]),
document_node_path: vec![],
skip_deduplication: false,
},
),
(
2,
ProtoNode {
identifier: "id".into(),
input: ProtoNodeInput::Node(1, false),
construction_args: ConstructionArgs::Nodes(vec![]),
document_node_path: vec![],
skip_deduplication: false,
},
),
]
.into_iter()
.collect(),
}
}
}

View file

@ -26,10 +26,10 @@ pub struct GpuCompiler<TypingContext, ShaderIO> {
// TODO: Move to graph-craft
#[node_macro::node_fn(GpuCompiler)]
async fn compile_gpu(node: &'input DocumentNode, mut typing_context: TypingContext, io: ShaderIO) -> compilation_client::Shader {
async fn compile_gpu(node: &'input DocumentNode, mut typing_context: TypingContext, io: ShaderIO) -> Result<compilation_client::Shader, String> {
let compiler = graph_craft::graphene_compiler::Compiler {};
let DocumentNodeImplementation::Network(ref network) = node.implementation else { panic!() };
let proto_networks: Vec<_> = compiler.compile(network.clone()).collect();
let proto_networks: Vec<_> = compiler.compile(network.clone())?.collect();
for network in proto_networks.iter() {
typing_context.update(network).expect("Failed to type check network");
@ -43,7 +43,7 @@ async fn compile_gpu(node: &'input DocumentNode, mut typing_context: TypingConte
.collect();
let output_types = proto_networks.iter().map(|network| typing_context.type_of(network.output).unwrap().output.clone()).collect();
compilation_client::compile(proto_networks, input_types, output_types, io).await.unwrap()
Ok(compilation_client::compile(proto_networks, input_types, output_types, io).await.unwrap())
}
pub struct MapGpuNode<Node, EditorApi> {
@ -97,7 +97,10 @@ async fn map_gpu<'a: 'input>(image: ImageFrame<Color>, node: DocumentNode, edito
self.cache.borrow().get(&node.name).unwrap().clone()
} else {
let name = node.name.clone();
let compute_pass_descriptor = create_compute_pass_descriptor(node, &image, executor, quantization).await;
let Ok(compute_pass_descriptor) = create_compute_pass_descriptor(node, &image, executor, quantization).await else {
log::error!("Error creating compute pass descriptor in 'map_gpu()");
return ImageFrame::empty();
};
self.cache.borrow_mut().insert(name, compute_pass_descriptor.clone());
log::error!("created compute pass");
compute_pass_descriptor
@ -156,7 +159,7 @@ async fn create_compute_pass_descriptor<T: Clone + Pixel + StaticTypeSized>(
image: &ImageFrame<T>,
executor: &&WgpuExecutor,
quantization: QuantizationChannels,
) -> ComputePass<WgpuExecutor> {
) -> Result<ComputePass<WgpuExecutor>, String> {
let compiler = graph_craft::graphene_compiler::Compiler {};
let inner_network = NodeNetwork::value_network(node);
@ -246,7 +249,7 @@ async fn create_compute_pass_descriptor<T: Clone + Pixel + StaticTypeSized>(
..Default::default()
};
log::debug!("compiling network");
let proto_networks = compiler.compile(network.clone()).collect();
let proto_networks = compiler.compile(network.clone())?.collect();
log::debug!("compiling shader");
let shader = compilation_client::compile(
proto_networks,
@ -344,10 +347,10 @@ async fn create_compute_pass_descriptor<T: Clone + Pixel + StaticTypeSized>(
};
log::debug!("created pipeline");
ComputePass {
Ok(ComputePass {
pipeline_layout: pipeline,
readback_buffer: Some(readback_buffer.clone()),
}
})
}
/*
#[node_macro::node_fn(MapGpuNode)]
@ -417,7 +420,7 @@ pub struct BlendGpuImageNode<Background, B, O> {
async fn blend_gpu_image(foreground: ImageFrame<Color>, background: ImageFrame<Color>, blend_mode: BlendMode, opacity: f32) -> ImageFrame<Color> {
let foreground_size = DVec2::new(foreground.image.width as f64, foreground.image.height as f64);
let background_size = DVec2::new(background.image.width as f64, background.image.height as f64);
// Transforms a point from the background image to the forground image
// Transforms a point from the background image to the foreground image
let bg_to_fg = DAffine2::from_scale(foreground_size) * foreground.transform.inverse() * background.transform * DAffine2::from_scale(1. / background_size);
let transform_matrix: Mat2 = bg_to_fg.matrix2.as_mat2();
@ -464,7 +467,11 @@ async fn blend_gpu_image(foreground: ImageFrame<Color>, background: ImageFrame<C
..Default::default()
};
log::debug!("compiling network");
let proto_networks = compiler.compile(network.clone()).collect();
let Ok(proto_networks_result) = compiler.compile(network.clone()) else {
log::error!("Error compiling network in 'blend_gpu_image()");
return ImageFrame::empty();
};
let proto_networks = proto_networks_result.collect();
log::debug!("compiling shader");
let shader = compilation_client::compile(