mirror of
https://github.com/GraphiteEditor/Graphite.git
synced 2025-08-03 13:02:20 +00:00
Fix crash when a cycle is introduced into the graph (#1427)
* Changing return of topological_sort to Result and propagating error * Simplifying "compile()" method, adding "expect()" to tests. * Removing Result type from "map_gpu()" * Reverting to assertion and removing unnecessary returns
This commit is contained in:
parent
7e3469fa3f
commit
b2397b06c6
3 changed files with 94 additions and 40 deletions
|
@ -8,7 +8,7 @@ use crate::proto::{LocalFuture, ProtoNetwork};
|
|||
pub struct Compiler {}
|
||||
|
||||
impl Compiler {
|
||||
pub fn compile(&self, mut network: NodeNetwork) -> impl Iterator<Item = ProtoNetwork> {
|
||||
pub fn compile(&self, mut network: NodeNetwork) -> Result<impl Iterator<Item = ProtoNetwork>, String> {
|
||||
println!("flattening");
|
||||
let node_ids = network.nodes.keys().copied().collect::<Vec<_>>();
|
||||
for id in node_ids {
|
||||
|
@ -17,15 +17,20 @@ impl Compiler {
|
|||
network.remove_redundant_id_nodes();
|
||||
network.remove_dead_nodes();
|
||||
let proto_networks = network.into_proto_networks();
|
||||
proto_networks.map(move |mut proto_network| {
|
||||
proto_network.resolve_inputs();
|
||||
proto_network.generate_stable_node_ids();
|
||||
proto_network
|
||||
})
|
||||
|
||||
let proto_networks_result: Vec<ProtoNetwork> = proto_networks
|
||||
.map(move |mut proto_network| {
|
||||
proto_network.resolve_inputs()?;
|
||||
proto_network.generate_stable_node_ids();
|
||||
Ok(proto_network)
|
||||
})
|
||||
.collect::<Result<Vec<ProtoNetwork>, String>>()?;
|
||||
|
||||
Ok(proto_networks_result.into_iter())
|
||||
}
|
||||
pub fn compile_single(&self, network: NodeNetwork) -> Result<ProtoNetwork, String> {
|
||||
assert_eq!(network.outputs.len(), 1, "Graph with multiple outputs not yet handled");
|
||||
let Some(proto_network) = self.compile(network).next() else {
|
||||
let Some(proto_network) = self.compile(network)?.next() else {
|
||||
return Err("Failed to convert graph into proto graph".to_string());
|
||||
};
|
||||
Ok(proto_network)
|
||||
|
|
|
@ -336,9 +336,9 @@ impl ProtoNetwork {
|
|||
edges
|
||||
}
|
||||
|
||||
pub fn resolve_inputs(&mut self) {
|
||||
pub fn resolve_inputs(&mut self) -> Result<(), String> {
|
||||
// Perform topological sort once
|
||||
self.reorder_ids();
|
||||
self.reorder_ids()?;
|
||||
|
||||
let max_id = self.nodes.len() as NodeId - 1;
|
||||
|
||||
|
@ -370,7 +370,8 @@ impl ProtoNetwork {
|
|||
self.replace_node_id(&outwards_edges, node_id, compose_node_id, true);
|
||||
}
|
||||
}
|
||||
self.reorder_ids();
|
||||
self.reorder_ids()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn replace_node_id(&mut self, outwards_edges: &HashMap<u64, Vec<u64>>, node_id: u64, compose_node_id: u64, skip_lambdas: bool) {
|
||||
|
@ -392,33 +393,35 @@ impl ProtoNetwork {
|
|||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Based on https://en.wikipedia.org/wiki/Topological_sorting#Depth-first_search
|
||||
// This approach excludes nodes that are not connected
|
||||
pub fn topological_sort(&self) -> Vec<NodeId> {
|
||||
pub fn topological_sort(&self) -> Result<Vec<NodeId>, String> {
|
||||
let mut sorted = Vec::new();
|
||||
let inwards_edges = self.collect_inwards_edges();
|
||||
fn visit(node_id: NodeId, temp_marks: &mut HashSet<NodeId>, sorted: &mut Vec<NodeId>, inwards_edges: &HashMap<NodeId, Vec<NodeId>>, network: &ProtoNetwork) {
|
||||
fn visit(node_id: NodeId, temp_marks: &mut HashSet<NodeId>, sorted: &mut Vec<NodeId>, inwards_edges: &HashMap<NodeId, Vec<NodeId>>, network: &ProtoNetwork) -> Result<(), String> {
|
||||
if sorted.contains(&node_id) {
|
||||
return;
|
||||
return Ok(());
|
||||
};
|
||||
if temp_marks.contains(&node_id) {
|
||||
panic!("Cycle detected {:#?}, {:#?}", &inwards_edges, &network);
|
||||
return Err(format!("Cycle detected {:#?}, {:#?}", &inwards_edges, &network));
|
||||
}
|
||||
|
||||
if let Some(dependencies) = inwards_edges.get(&node_id) {
|
||||
temp_marks.insert(node_id);
|
||||
for &dependant in dependencies {
|
||||
visit(dependant, temp_marks, sorted, inwards_edges, network);
|
||||
visit(dependant, temp_marks, sorted, inwards_edges, network)?;
|
||||
}
|
||||
temp_marks.remove(&node_id);
|
||||
}
|
||||
sorted.push(node_id);
|
||||
Ok(())
|
||||
}
|
||||
assert!(self.nodes.iter().any(|(id, _)| *id == self.output), "Output id {} does not exist", self.output);
|
||||
visit(self.output, &mut HashSet::new(), &mut sorted, &inwards_edges, self);
|
||||
|
||||
sorted
|
||||
if !self.nodes.iter().any(|(id, _)| *id == self.output) {
|
||||
return Err(format!("Output id {} does not exist", self.output));
|
||||
}
|
||||
visit(self.output, &mut HashSet::new(), &mut sorted, &inwards_edges, self)?;
|
||||
Ok(sorted)
|
||||
}
|
||||
|
||||
fn is_topologically_sorted(&self) -> bool {
|
||||
|
@ -465,8 +468,8 @@ impl ProtoNetwork {
|
|||
sorted
|
||||
}*/
|
||||
|
||||
fn reorder_ids(&mut self) {
|
||||
let order = self.topological_sort();
|
||||
fn reorder_ids(&mut self) -> Result<(), String> {
|
||||
let order = self.topological_sort()?;
|
||||
|
||||
// Map of node ids to their current index in the nodes vector
|
||||
let current_positions: HashMap<_, _> = self.nodes.iter().enumerate().map(|(pos, (id, _))| (*id, pos)).collect();
|
||||
|
@ -492,6 +495,7 @@ impl ProtoNetwork {
|
|||
self.output = *new_positions.get(&self.output).unwrap();
|
||||
|
||||
assert_eq!(order.len(), self.nodes.len());
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -687,17 +691,24 @@ mod test {
|
|||
#[test]
|
||||
fn topological_sort() {
|
||||
let construction_network = test_network();
|
||||
let sorted = construction_network.topological_sort();
|
||||
|
||||
let sorted = construction_network.topological_sort().expect("Error when calling 'topological_sort' on 'construction_network.");
|
||||
println!("{:#?}", sorted);
|
||||
assert_eq!(sorted, vec![14, 10, 11, 1]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn topological_sort_with_cycles() {
|
||||
let construction_network = test_network_with_cycles();
|
||||
let sorted = construction_network.topological_sort();
|
||||
|
||||
assert!(sorted.is_err())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn id_reordering() {
|
||||
let mut construction_network = test_network();
|
||||
construction_network.reorder_ids();
|
||||
let sorted = construction_network.topological_sort();
|
||||
construction_network.reorder_ids().expect("Error when calling 'reorder_ids' on 'construction_network.");
|
||||
let sorted = construction_network.topological_sort().expect("Error when calling 'topological_sort' on 'construction_network.");
|
||||
println!("nodes: {:#?}", construction_network.nodes);
|
||||
assert_eq!(sorted, vec![0, 1, 2, 3]);
|
||||
let ids: Vec<_> = construction_network.nodes.iter().map(|(id, _)| *id).collect();
|
||||
|
@ -710,9 +721,9 @@ mod test {
|
|||
#[test]
|
||||
fn id_reordering_idempotent() {
|
||||
let mut construction_network = test_network();
|
||||
construction_network.reorder_ids();
|
||||
construction_network.reorder_ids();
|
||||
let sorted = construction_network.topological_sort();
|
||||
construction_network.reorder_ids().expect("Error when calling 'reorder_ids' on 'construction_network.");
|
||||
construction_network.reorder_ids().expect("Error when calling 'reorder_ids' on 'construction_network.");
|
||||
let sorted = construction_network.topological_sort().expect("Error when calling 'topological_sort' on 'construction_network.");
|
||||
assert_eq!(sorted, vec![0, 1, 2, 3]);
|
||||
let ids: Vec<_> = construction_network.nodes.iter().map(|(id, _)| *id).collect();
|
||||
println!("{:#?}", ids);
|
||||
|
@ -723,7 +734,7 @@ mod test {
|
|||
#[test]
|
||||
fn input_resolution() {
|
||||
let mut construction_network = test_network();
|
||||
construction_network.resolve_inputs();
|
||||
construction_network.resolve_inputs().expect("Error when calling 'resolve_inputs' on 'construction_network.");
|
||||
println!("{:#?}", construction_network);
|
||||
assert_eq!(construction_network.nodes[0].1.identifier.name.as_ref(), "value");
|
||||
assert_eq!(construction_network.nodes.len(), 6);
|
||||
|
@ -733,7 +744,7 @@ mod test {
|
|||
#[test]
|
||||
fn stable_node_id_generation() {
|
||||
let mut construction_network = test_network();
|
||||
construction_network.resolve_inputs();
|
||||
construction_network.resolve_inputs().expect("Error when calling 'resolve_inputs' on 'construction_network.");
|
||||
construction_network.generate_stable_node_ids();
|
||||
assert_eq!(construction_network.nodes[0].1.identifier.name.as_ref(), "value");
|
||||
let ids: Vec<_> = construction_network.nodes.iter().map(|(id, _)| *id).collect();
|
||||
|
@ -810,4 +821,35 @@ mod test {
|
|||
.collect(),
|
||||
}
|
||||
}
|
||||
|
||||
fn test_network_with_cycles() -> ProtoNetwork {
|
||||
ProtoNetwork {
|
||||
inputs: vec![1],
|
||||
output: 1,
|
||||
nodes: [
|
||||
(
|
||||
1,
|
||||
ProtoNode {
|
||||
identifier: "id".into(),
|
||||
input: ProtoNodeInput::Node(2, false),
|
||||
construction_args: ConstructionArgs::Nodes(vec![]),
|
||||
document_node_path: vec![],
|
||||
skip_deduplication: false,
|
||||
},
|
||||
),
|
||||
(
|
||||
2,
|
||||
ProtoNode {
|
||||
identifier: "id".into(),
|
||||
input: ProtoNodeInput::Node(1, false),
|
||||
construction_args: ConstructionArgs::Nodes(vec![]),
|
||||
document_node_path: vec![],
|
||||
skip_deduplication: false,
|
||||
},
|
||||
),
|
||||
]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -26,10 +26,10 @@ pub struct GpuCompiler<TypingContext, ShaderIO> {
|
|||
|
||||
// TODO: Move to graph-craft
|
||||
#[node_macro::node_fn(GpuCompiler)]
|
||||
async fn compile_gpu(node: &'input DocumentNode, mut typing_context: TypingContext, io: ShaderIO) -> compilation_client::Shader {
|
||||
async fn compile_gpu(node: &'input DocumentNode, mut typing_context: TypingContext, io: ShaderIO) -> Result<compilation_client::Shader, String> {
|
||||
let compiler = graph_craft::graphene_compiler::Compiler {};
|
||||
let DocumentNodeImplementation::Network(ref network) = node.implementation else { panic!() };
|
||||
let proto_networks: Vec<_> = compiler.compile(network.clone()).collect();
|
||||
let proto_networks: Vec<_> = compiler.compile(network.clone())?.collect();
|
||||
|
||||
for network in proto_networks.iter() {
|
||||
typing_context.update(network).expect("Failed to type check network");
|
||||
|
@ -43,7 +43,7 @@ async fn compile_gpu(node: &'input DocumentNode, mut typing_context: TypingConte
|
|||
.collect();
|
||||
let output_types = proto_networks.iter().map(|network| typing_context.type_of(network.output).unwrap().output.clone()).collect();
|
||||
|
||||
compilation_client::compile(proto_networks, input_types, output_types, io).await.unwrap()
|
||||
Ok(compilation_client::compile(proto_networks, input_types, output_types, io).await.unwrap())
|
||||
}
|
||||
|
||||
pub struct MapGpuNode<Node, EditorApi> {
|
||||
|
@ -97,7 +97,10 @@ async fn map_gpu<'a: 'input>(image: ImageFrame<Color>, node: DocumentNode, edito
|
|||
self.cache.borrow().get(&node.name).unwrap().clone()
|
||||
} else {
|
||||
let name = node.name.clone();
|
||||
let compute_pass_descriptor = create_compute_pass_descriptor(node, &image, executor, quantization).await;
|
||||
let Ok(compute_pass_descriptor) = create_compute_pass_descriptor(node, &image, executor, quantization).await else {
|
||||
log::error!("Error creating compute pass descriptor in 'map_gpu()");
|
||||
return ImageFrame::empty();
|
||||
};
|
||||
self.cache.borrow_mut().insert(name, compute_pass_descriptor.clone());
|
||||
log::error!("created compute pass");
|
||||
compute_pass_descriptor
|
||||
|
@ -156,7 +159,7 @@ async fn create_compute_pass_descriptor<T: Clone + Pixel + StaticTypeSized>(
|
|||
image: &ImageFrame<T>,
|
||||
executor: &&WgpuExecutor,
|
||||
quantization: QuantizationChannels,
|
||||
) -> ComputePass<WgpuExecutor> {
|
||||
) -> Result<ComputePass<WgpuExecutor>, String> {
|
||||
let compiler = graph_craft::graphene_compiler::Compiler {};
|
||||
let inner_network = NodeNetwork::value_network(node);
|
||||
|
||||
|
@ -246,7 +249,7 @@ async fn create_compute_pass_descriptor<T: Clone + Pixel + StaticTypeSized>(
|
|||
..Default::default()
|
||||
};
|
||||
log::debug!("compiling network");
|
||||
let proto_networks = compiler.compile(network.clone()).collect();
|
||||
let proto_networks = compiler.compile(network.clone())?.collect();
|
||||
log::debug!("compiling shader");
|
||||
let shader = compilation_client::compile(
|
||||
proto_networks,
|
||||
|
@ -344,10 +347,10 @@ async fn create_compute_pass_descriptor<T: Clone + Pixel + StaticTypeSized>(
|
|||
};
|
||||
log::debug!("created pipeline");
|
||||
|
||||
ComputePass {
|
||||
Ok(ComputePass {
|
||||
pipeline_layout: pipeline,
|
||||
readback_buffer: Some(readback_buffer.clone()),
|
||||
}
|
||||
})
|
||||
}
|
||||
/*
|
||||
#[node_macro::node_fn(MapGpuNode)]
|
||||
|
@ -417,7 +420,7 @@ pub struct BlendGpuImageNode<Background, B, O> {
|
|||
async fn blend_gpu_image(foreground: ImageFrame<Color>, background: ImageFrame<Color>, blend_mode: BlendMode, opacity: f32) -> ImageFrame<Color> {
|
||||
let foreground_size = DVec2::new(foreground.image.width as f64, foreground.image.height as f64);
|
||||
let background_size = DVec2::new(background.image.width as f64, background.image.height as f64);
|
||||
// Transforms a point from the background image to the forground image
|
||||
// Transforms a point from the background image to the foreground image
|
||||
let bg_to_fg = DAffine2::from_scale(foreground_size) * foreground.transform.inverse() * background.transform * DAffine2::from_scale(1. / background_size);
|
||||
|
||||
let transform_matrix: Mat2 = bg_to_fg.matrix2.as_mat2();
|
||||
|
@ -464,7 +467,11 @@ async fn blend_gpu_image(foreground: ImageFrame<Color>, background: ImageFrame<C
|
|||
..Default::default()
|
||||
};
|
||||
log::debug!("compiling network");
|
||||
let proto_networks = compiler.compile(network.clone()).collect();
|
||||
let Ok(proto_networks_result) = compiler.compile(network.clone()) else {
|
||||
log::error!("Error compiling network in 'blend_gpu_image()");
|
||||
return ImageFrame::empty();
|
||||
};
|
||||
let proto_networks = proto_networks_result.collect();
|
||||
log::debug!("compiling shader");
|
||||
|
||||
let shader = compilation_client::compile(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue