mirror of
https://github.com/GraphiteEditor/Graphite.git
synced 2025-08-04 13:30:48 +00:00
Remove unsafe from node graph evaluation
This commit is contained in:
parent
9665fa0b47
commit
88fa67e2ff
8 changed files with 92 additions and 98 deletions
|
@ -36,6 +36,15 @@ unsafe impl StaticType for SurfaceFrame {
|
|||
type Static = SurfaceFrame;
|
||||
}
|
||||
|
||||
impl<'a, S> From<SurfaceHandleFrame<'a, S>> for SurfaceFrame {
|
||||
fn from(x: SurfaceHandleFrame<'a, S>) -> Self {
|
||||
Self {
|
||||
surface_id: x.surface_handle.surface_id,
|
||||
transform: x.transform,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct SurfaceHandle<'a, Surface> {
|
||||
pub surface_id: SurfaceId,
|
||||
|
|
|
@ -245,56 +245,60 @@ impl<'a> TaggedValue {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn try_from_any(input: Box<dyn DynAny<'a> + 'a>) -> Option<Self> {
|
||||
pub fn try_from_any(input: Box<dyn DynAny<'a> + 'a>) -> Result<Self, String> {
|
||||
use dyn_any::downcast;
|
||||
use std::any::TypeId;
|
||||
|
||||
match DynAny::type_id(input.as_ref()) {
|
||||
x if x == TypeId::of::<()>() => Some(TaggedValue::None),
|
||||
x if x == TypeId::of::<String>() => Some(TaggedValue::String(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<u32>() => Some(TaggedValue::U32(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<f32>() => Some(TaggedValue::F32(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<f64>() => Some(TaggedValue::F64(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<bool>() => Some(TaggedValue::Bool(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<DVec2>() => Some(TaggedValue::DVec2(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<Option<DVec2>>() => Some(TaggedValue::OptionalDVec2(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<graphene_core::raster::Image<Color>>() => Some(TaggedValue::Image(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<Option<Arc<graphene_core::raster::Image<Color>>>>() => Some(TaggedValue::RcImage(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<graphene_core::raster::ImageFrame<Color>>() => Some(TaggedValue::ImageFrame(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<graphene_core::raster::Color>() => Some(TaggedValue::Color(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<Vec<bezier_rs::Subpath<graphene_core::uuid::ManipulatorGroupId>>>() => Some(TaggedValue::Subpaths(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<Arc<bezier_rs::Subpath<graphene_core::uuid::ManipulatorGroupId>>>() => Some(TaggedValue::RcSubpath(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<BlendMode>() => Some(TaggedValue::BlendMode(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<ImaginateSamplingMethod>() => Some(TaggedValue::ImaginateSamplingMethod(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<ImaginateMaskStartingFill>() => Some(TaggedValue::ImaginateMaskStartingFill(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<ImaginateStatus>() => Some(TaggedValue::ImaginateStatus(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<Option<Vec<u64>>>() => Some(TaggedValue::LayerPath(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<DAffine2>() => Some(TaggedValue::DAffine2(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<LuminanceCalculation>() => Some(TaggedValue::LuminanceCalculation(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<graphene_core::vector::VectorData>() => Some(TaggedValue::VectorData(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<graphene_core::vector::style::Fill>() => Some(TaggedValue::Fill(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<graphene_core::vector::style::Stroke>() => Some(TaggedValue::Stroke(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<Vec<f32>>() => Some(TaggedValue::VecF32(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<graphene_core::raster::RedGreenBlue>() => Some(TaggedValue::RedGreenBlue(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<graphene_core::raster::RelativeAbsolute>() => Some(TaggedValue::RelativeAbsolute(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<graphene_core::raster::SelectiveColorChoice>() => Some(TaggedValue::SelectiveColorChoice(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<graphene_core::vector::style::LineCap>() => Some(TaggedValue::LineCap(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<graphene_core::vector::style::LineJoin>() => Some(TaggedValue::LineJoin(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<graphene_core::vector::style::FillType>() => Some(TaggedValue::FillType(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<graphene_core::vector::style::GradientType>() => Some(TaggedValue::GradientType(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<Vec<(f64, Option<graphene_core::Color>)>>() => Some(TaggedValue::GradientPositions(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<graphene_core::quantization::QuantizationChannels>() => Some(TaggedValue::Quantization(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<Option<graphene_core::Color>>() => Some(TaggedValue::OptionalColor(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<Vec<graphene_core::uuid::ManipulatorGroupId>>() => Some(TaggedValue::ManipulatorGroupIds(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<graphene_core::text::Font>() => Some(TaggedValue::Font(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<Vec<graphene_core::vector::brush_stroke::BrushStroke>>() => Some(TaggedValue::BrushStrokes(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<graphene_core::raster::IndexNode<Vec<graphene_core::raster::ImageFrame<Color>>>>() => Some(TaggedValue::Segments(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<crate::document::DocumentNode>() => Some(TaggedValue::DocumentNode(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<graphene_core::GraphicGroup>() => Some(TaggedValue::GraphicGroup(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<graphene_core::Artboard>() => Some(TaggedValue::Artboard(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<glam::IVec2>() => Some(TaggedValue::IVec2(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<graphene_core::SurfaceFrame>() => Some(TaggedValue::SurfaceFrame(*downcast(input).unwrap())),
|
||||
_ => None,
|
||||
x if x == TypeId::of::<()>() => Ok(TaggedValue::None),
|
||||
x if x == TypeId::of::<String>() => Ok(TaggedValue::String(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<u32>() => Ok(TaggedValue::U32(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<f32>() => Ok(TaggedValue::F32(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<f64>() => Ok(TaggedValue::F64(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<bool>() => Ok(TaggedValue::Bool(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<DVec2>() => Ok(TaggedValue::DVec2(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<Option<DVec2>>() => Ok(TaggedValue::OptionalDVec2(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<graphene_core::raster::Image<Color>>() => Ok(TaggedValue::Image(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<Option<Arc<graphene_core::raster::Image<Color>>>>() => Ok(TaggedValue::RcImage(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<graphene_core::raster::ImageFrame<Color>>() => Ok(TaggedValue::ImageFrame(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<graphene_core::raster::Color>() => Ok(TaggedValue::Color(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<Vec<bezier_rs::Subpath<graphene_core::uuid::ManipulatorGroupId>>>() => Ok(TaggedValue::Subpaths(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<Arc<bezier_rs::Subpath<graphene_core::uuid::ManipulatorGroupId>>>() => Ok(TaggedValue::RcSubpath(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<BlendMode>() => Ok(TaggedValue::BlendMode(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<ImaginateSamplingMethod>() => Ok(TaggedValue::ImaginateSamplingMethod(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<ImaginateMaskStartingFill>() => Ok(TaggedValue::ImaginateMaskStartingFill(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<ImaginateStatus>() => Ok(TaggedValue::ImaginateStatus(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<Option<Vec<u64>>>() => Ok(TaggedValue::LayerPath(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<DAffine2>() => Ok(TaggedValue::DAffine2(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<LuminanceCalculation>() => Ok(TaggedValue::LuminanceCalculation(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<graphene_core::vector::VectorData>() => Ok(TaggedValue::VectorData(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<graphene_core::vector::style::Fill>() => Ok(TaggedValue::Fill(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<graphene_core::vector::style::Stroke>() => Ok(TaggedValue::Stroke(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<Vec<f32>>() => Ok(TaggedValue::VecF32(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<graphene_core::raster::RedGreenBlue>() => Ok(TaggedValue::RedGreenBlue(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<graphene_core::raster::RelativeAbsolute>() => Ok(TaggedValue::RelativeAbsolute(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<graphene_core::raster::SelectiveColorChoice>() => Ok(TaggedValue::SelectiveColorChoice(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<graphene_core::vector::style::LineCap>() => Ok(TaggedValue::LineCap(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<graphene_core::vector::style::LineJoin>() => Ok(TaggedValue::LineJoin(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<graphene_core::vector::style::FillType>() => Ok(TaggedValue::FillType(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<graphene_core::vector::style::GradientType>() => Ok(TaggedValue::GradientType(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<Vec<(f64, Option<graphene_core::Color>)>>() => Ok(TaggedValue::GradientPositions(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<graphene_core::quantization::QuantizationChannels>() => Ok(TaggedValue::Quantization(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<Option<graphene_core::Color>>() => Ok(TaggedValue::OptionalColor(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<Vec<graphene_core::uuid::ManipulatorGroupId>>() => Ok(TaggedValue::ManipulatorGroupIds(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<graphene_core::text::Font>() => Ok(TaggedValue::Font(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<Vec<graphene_core::vector::brush_stroke::BrushStroke>>() => Ok(TaggedValue::BrushStrokes(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<graphene_core::raster::IndexNode<Vec<graphene_core::raster::ImageFrame<Color>>>>() => Ok(TaggedValue::Segments(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<crate::document::DocumentNode>() => Ok(TaggedValue::DocumentNode(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<graphene_core::GraphicGroup>() => Ok(TaggedValue::GraphicGroup(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<graphene_core::Artboard>() => Ok(TaggedValue::Artboard(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<glam::IVec2>() => Ok(TaggedValue::IVec2(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<graphene_core::SurfaceFrame>() => Ok(TaggedValue::SurfaceFrame(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<graphene_core::wasm_application_io::WasmSurfaceHandleFrame>() => {
|
||||
let frame = *downcast::<graphene_core::wasm_application_io::WasmSurfaceHandleFrame>(input).unwrap();
|
||||
Ok(TaggedValue::SurfaceFrame(frame.into()))
|
||||
}
|
||||
_ => Err(format!("Cannot convert {:?} to TaggedValue", DynAny::type_name(input.as_ref()))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use std::error::Error;
|
||||
|
||||
use dyn_any::DynAny;
|
||||
use dyn_any::{DynAny, StaticType};
|
||||
|
||||
use crate::document::NodeNetwork;
|
||||
use crate::proto::{LocalFuture, ProtoNetwork};
|
||||
|
@ -38,6 +38,6 @@ impl Compiler {
|
|||
}
|
||||
pub type Any<'a> = Box<dyn DynAny<'a> + 'a>;
|
||||
|
||||
pub trait Executor {
|
||||
fn execute<'a>(&'a self, input: Any<'a>) -> LocalFuture<Result<Any<'a>, Box<dyn Error>>>;
|
||||
pub trait Executor<I, O> {
|
||||
fn execute(&self, input: I) -> LocalFuture<Result<O, Box<dyn Error>>>;
|
||||
}
|
||||
|
|
|
@ -3,7 +3,7 @@ use std::error::Error;
|
|||
use std::sync::{Arc, RwLock};
|
||||
|
||||
use dyn_any::StaticType;
|
||||
use graph_craft::document::value::UpcastNode;
|
||||
use graph_craft::document::value::{TaggedValue, UpcastNode};
|
||||
use graph_craft::document::NodeId;
|
||||
use graph_craft::executor::Executor;
|
||||
use graph_craft::proto::{ConstructionArgs, LocalFuture, ProtoNetwork, ProtoNode, TypingContext};
|
||||
|
@ -73,9 +73,9 @@ impl DynamicExecutor {
|
|||
}
|
||||
}
|
||||
|
||||
impl Executor for DynamicExecutor {
|
||||
fn execute<'a>(&'a self, input: Any<'a>) -> LocalFuture<Result<Any<'a>, Box<dyn Error>>> {
|
||||
Box::pin(async move { self.tree.eval_any(self.output, input).await.ok_or_else(|| "Failed to execute".into()) })
|
||||
impl<'a, I: StaticType + 'a> Executor<I, TaggedValue> for &'a DynamicExecutor {
|
||||
fn execute(&self, input: I) -> LocalFuture<Result<TaggedValue, Box<dyn Error>>> {
|
||||
Box::pin(async move { self.tree.eval_tagged_value(self.output, input).await.map_err(|e| e.into()) })
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -167,18 +167,17 @@ impl BorrowTree {
|
|||
self.nodes.get(&id).cloned()
|
||||
}
|
||||
|
||||
pub async fn eval<'i, I: StaticType + 'i + Send + Sync, O: StaticType + Send + Sync + 'i>(&'i self, id: NodeId, input: I) -> Option<O> {
|
||||
pub async fn eval<'i, I: StaticType + 'i, O: StaticType + 'i>(&'i self, id: NodeId, input: I) -> Option<O> {
|
||||
let node = self.nodes.get(&id).cloned()?;
|
||||
let reader = node.read().unwrap();
|
||||
let output = reader.node.eval(Box::new(input));
|
||||
dyn_any::downcast::<O>(output.await).ok().map(|o| *o)
|
||||
}
|
||||
pub async fn eval_any<'i>(&'i self, id: NodeId, input: Any<'i>) -> Option<Any<'i>> {
|
||||
let node = self.nodes.get(&id)?;
|
||||
// TODO: Comments by @TrueDoctor before this was merged:
|
||||
// TODO: Oof I dislike the evaluation being an unsafe operation but I guess its fine because it only is a lifetime extension
|
||||
// TODO: We should ideally let miri run on a test that evaluates the nodegraph multiple times to check if this contains any subtle UB but this looks fine for now
|
||||
Some(unsafe { (*((&*node.read().unwrap()) as *const NodeContainer)).node.eval(input).await })
|
||||
pub async fn eval_tagged_value<'i, I: StaticType + 'i>(&'i self, id: NodeId, input: I) -> Result<TaggedValue, String> {
|
||||
let node = self.nodes.get(&id).cloned().ok_or_else(|| "Output node not found in executor")?;
|
||||
let reader = node.read().unwrap();
|
||||
let output = reader.node.eval(Box::new(input));
|
||||
TaggedValue::try_from_any(output.await)
|
||||
}
|
||||
|
||||
pub fn free_node(&mut self, id: NodeId) {
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
use std::error::Error;
|
||||
|
||||
use super::context::Context;
|
||||
|
||||
use graph_craft::executor::{Any, Executor};
|
||||
|
@ -38,9 +40,8 @@ impl<I: StaticTypeSized, O> GpuExecutor<I, O> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<I: StaticTypeSized + Sync + Pod + Send, O: StaticTypeSized + Send + Sync + Pod> Executor for GpuExecutor<I, O> {
|
||||
fn execute<'i>(&'i self, input: Any<'i>) -> LocalFuture<Result<Any<'i>, Box<dyn std::error::Error>>> {
|
||||
let input = dyn_any::downcast::<Vec<I>>(input).expect("Wrong input type");
|
||||
impl<'a, I: StaticTypeSized + Sync + Pod + Send + 'a, O: StaticTypeSized + Send + Sync + Pod + 'a> Executor<Vec<I>, Vec<O>> for &'a GpuExecutor<I, O> {
|
||||
fn execute(&self, input: Vec<I>) -> LocalFuture<Result<Vec<O>, Box<dyn Error>>> {
|
||||
let context = &self.context;
|
||||
let result: Vec<O> = execute_shader(
|
||||
context.device.clone(),
|
||||
|
@ -48,9 +49,9 @@ impl<I: StaticTypeSized + Sync + Pod + Send, O: StaticTypeSized + Send + Sync +
|
|||
self.shader.entry_point(&self.entry_point).expect("Entry point not found in shader"),
|
||||
&context.allocator,
|
||||
&context.command_buffer_allocator,
|
||||
*input,
|
||||
input,
|
||||
);
|
||||
Box::pin(async move { Ok(Box::new(result) as Any) })
|
||||
Box::pin(async move { Ok(result) })
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
use std::borrow::Cow;
|
||||
use std::sync::Arc;
|
||||
use std::{borrow::Cow, error::Error};
|
||||
use wgpu::util::DeviceExt;
|
||||
|
||||
use super::context::Context;
|
||||
|
@ -29,16 +29,15 @@ impl<'a, I: StaticTypeSized, O> GpuExecutor<'a, I, O> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'a, I: StaticTypeSized + Sync + Pod + Send, O: StaticTypeSized + Send + Sync + Pod> Executor for GpuExecutor<'a, I, O> {
|
||||
fn execute<'i>(&'i self, input: Any<'i>) -> LocalFuture<Result<Any<'i>, Box<dyn std::error::Error>>> {
|
||||
let input = dyn_any::downcast::<Vec<I>>(input).expect("Wrong input type");
|
||||
impl<'a, I: StaticTypeSized + Sync + Pod + Send, O: StaticTypeSized + Send + Sync + Pod> Executor<Vec<I>, Vec<O>> for GpuExecutor<'a, I, O> {
|
||||
fn execute(&self, input: Vec<I>) -> LocalFuture<Result<Vec<O>, Box<dyn Error>>> {
|
||||
let context = &self.context;
|
||||
let future = execute_shader(context.device.clone(), context.queue.clone(), self.shader.to_vec(), *input, self.entry_point.clone());
|
||||
let future = execute_shader(context.device.clone(), context.queue.clone(), self.shader.to_vec(), input, self.entry_point.clone());
|
||||
Box::pin(async move {
|
||||
let result = future.await;
|
||||
|
||||
let result: Vec<O> = result.ok_or_else(|| String::from("Failed to execute shader"))?;
|
||||
Ok(Box::new(result) as Any)
|
||||
Ok(result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue