Move Imaginate cache into the node

This commit is contained in:
Dennis Kobert 2023-07-14 16:40:56 +02:00 committed by Keavon Chambers
parent d52ea18a1f
commit 4c9daadb01
3 changed files with 18 additions and 12 deletions

View file

@ -1778,7 +1778,7 @@ pub static IMAGINATE_NODE: Lazy<DocumentNodeType> = Lazy::new(|| DocumentNodeTyp
name: "Imaginate",
category: "Image Synthesis",
identifier: NodeImplementation::DocumentNode(NodeNetwork {
inputs: vec![0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
inputs: vec![0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
outputs: vec![NodeOutput::new(1, 0)],
nodes: [
(
@ -1813,7 +1813,6 @@ pub static IMAGINATE_NODE: Lazy<DocumentNodeType> = Lazy::new(|| DocumentNodeTyp
NodeInput::Network(concrete!(ImaginateMaskStartingFill)),
NodeInput::Network(concrete!(bool)),
NodeInput::Network(concrete!(bool)),
NodeInput::Network(concrete!(ImaginateCache)),
],
implementation: DocumentNodeImplementation::proto("graphene_std::raster::ImaginateNode<_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _>"),
..Default::default()
@ -1846,7 +1845,6 @@ pub static IMAGINATE_NODE: Lazy<DocumentNodeType> = Lazy::new(|| DocumentNodeTyp
DocumentInputType::value("Mask Starting Fill", TaggedValue::ImaginateMaskStartingFill(ImaginateMaskStartingFill::Fill), false),
DocumentInputType::value("Improve Faces", TaggedValue::Bool(false), false),
DocumentInputType::value("Tiling", TaggedValue::Bool(false), false),
DocumentInputType::value("Cache", TaggedValue::ImaginateCache(Default::default()), false),
],
outputs: vec![DocumentOutputType::new("Image", FrontendGraphDataType::Raster)],
properties: node_properties::imaginate_properties,

View file

@ -10,7 +10,9 @@ use graphene_core::raster::bbox::{AxisAlignedBbox, Bbox};
use graphene_core::value::CopiedNode;
use graphene_core::{Color, Node};
use std::collections::HashMap;
use std::fmt::Debug;
use std::hash::Hash;
use std::marker::PhantomData;
use std::path::Path;
@ -424,7 +426,7 @@ macro_rules! generate_imaginate_node {
editor_api: E,
controller: C,
$($val: $t,)*
cache: std::sync::Arc<std::sync::Mutex<Image<P>>>,
cache: std::sync::Mutex<HashMap<u64, Image<P>>>,
}
impl<'e, P: Pixel, E, C, $($t,)*> ImaginateNode<P, E, C, $($t,)*>
@ -432,12 +434,12 @@ macro_rules! generate_imaginate_node {
E: for<'any_input> Node<'any_input, (), Output = DynFuture<'any_input, WasmEditorApi<'e>>>,
C: for<'any_input> Node<'any_input, (), Output = DynFuture<'any_input, ImaginateController>>,
{
pub fn new(editor_api: E, controller: C, $($val: $t,)* cache: std::sync::Arc<std::sync::Mutex<Image<P>>>) -> Self {
Self { editor_api, controller, $($val,)* cache }
pub fn new(editor_api: E, controller: C, $($val: $t,)* ) -> Self {
Self { editor_api, controller, $($val,)* cache: Default::default() }
}
}
impl<'i, 'e: 'i, P: Pixel + 'i, E: 'i, C: 'i, $($t: 'i,)*> Node<'i, ImageFrame<P>> for ImaginateNode<P, E, C, $($t,)*>
impl<'i, 'e: 'i, P: Pixel + 'i + Hash + Default, E: 'i, C: 'i, $($t: 'i,)*> Node<'i, ImageFrame<P>> for ImaginateNode<P, E, C, $($t,)*>
where $($t: for<'any_input> Node<'any_input, (), Output = DynFuture<'any_input, $o>>,)*
E: for<'any_input> Node<'any_input, (), Output = DynFuture<'any_input, WasmEditorApi<'e>>>,
C: for<'any_input> Node<'any_input, (), Output = DynFuture<'any_input, ImaginateController>>,
@ -447,19 +449,27 @@ macro_rules! generate_imaginate_node {
fn eval(&'i self, frame: ImageFrame<P>) -> Self::Output {
let controller = self.controller.eval(());
$(let $val = self.$val.eval(());)*
use std::hash::Hasher;
use xxhash_rust::xxh3::Xxh3;
let mut hasher = Xxh3::new();
frame.hash(&mut hasher);
let hash =hasher.finish();
Box::pin(async move {
let controller: std::pin::Pin<Box<dyn std::future::Future<Output = ImaginateController>>> = controller;
let controller: ImaginateController = controller.await;
if controller.take_regenerate_trigger() {
let editor_api = self.editor_api.eval(());
let image = super::imaginate::imaginate(frame.image, editor_api, controller, $($val,)*).await;
self.cache.lock().unwrap().clone_from(&image);
self.cache.lock().unwrap().insert(hash, image.clone());
return ImageFrame {
image,
..frame
}
}
let image = self.cache.lock().unwrap().clone();
let image = self.cache.lock().unwrap().get(&hash).cloned().unwrap_or_default();
ImageFrame {
image,
..frame

View file

@ -467,9 +467,8 @@ fn node_registry() -> HashMap<NodeIdentifier, HashMap<NodeIOTypes, NodeConstruct
|args: Vec<Arc<graph_craft::proto::NodeContainer>>| {
Box::pin(async move {
use graphene_std::raster::ImaginateNode;
let cache: ImaginateCache = graphene_std::any::input_node(args.last().unwrap().clone()).eval(()).await;
macro_rules! instanciate_imaginate_node {
($($i:expr,)*) => { ImaginateNode::new($(graphene_std::any::input_node(args[$i].clone()),)* cache.into_inner()) };
($($i:expr,)*) => { ImaginateNode::new($(graphene_std::any::input_node(args[$i].clone()),)* ) };
}
let node: ImaginateNode<Color, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _> = instanciate_imaginate_node!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,);
let any = graphene_std::any::DynAnyNode::new(node);
@ -497,7 +496,6 @@ fn node_registry() -> HashMap<NodeIdentifier, HashMap<NodeIOTypes, NodeConstruct
fn_type!(ImaginateMaskStartingFill),
fn_type!(bool),
fn_type!(bool),
fn_type!(ImaginateCache),
],
),
),