Restore the Imaginate node with the full node graph architecture (but a flaky deadlock remains) (#1908)

* Rework imaginate trigger mechanism

* Fix imaginate generation
This commit is contained in:
Dennis Kobert 2024-08-07 12:23:00 +02:00 committed by GitHub
parent 8041b1237c
commit 06a409f1c5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 85 additions and 36 deletions

View file

@ -25,7 +25,7 @@ impl std::cmp::PartialEq for ImaginateCache {
impl core::hash::Hash for ImaginateCache {
fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
self.0.lock().unwrap().hash(state);
self.0.try_lock().map(|g| g.hash(state));
}
}
@ -50,11 +50,11 @@ pub struct ImaginateController(Arc<InternalImaginateControl>);
impl ImaginateController {
pub fn get_status(&self) -> ImaginateStatus {
self.0.status.lock().as_deref().cloned().unwrap_or_default()
self.0.status.try_lock().as_deref().cloned().unwrap_or_default()
}
pub fn set_status(&self, status: ImaginateStatus) {
if let Ok(mut lock) = self.0.status.lock() {
if let Ok(mut lock) = self.0.status.try_lock() {
*lock = status
}
}
@ -68,13 +68,13 @@ impl ImaginateController {
}
pub fn request_termination(&self) {
if let Some(handle) = self.0.termination_sender.lock().ok().and_then(|mut lock| lock.take()) {
if let Some(handle) = self.0.termination_sender.try_lock().ok().and_then(|mut lock| lock.take()) {
handle.terminate()
}
}
pub fn set_termination_handle<H: ImaginateTerminationHandle>(&self, handle: Box<H>) {
if let Ok(mut lock) = self.0.termination_sender.lock() {
if let Ok(mut lock) = self.0.termination_sender.try_lock() {
*lock = Some(handle)
}
}

View file

@ -500,28 +500,31 @@ fn base64_to_image<D: AsRef<[u8]>, P: Pixel>(base64_data: D) -> Result<Image<P>,
}
pub fn pick_safe_imaginate_resolution((width, height): (f64, f64)) -> (u64, u64) {
const NATIVE_MODEL_RESOLUTION: f64 = 512.;
let size = if width * height == 0. { DVec2::splat(NATIVE_MODEL_RESOLUTION) } else { DVec2::new(width, height) };
const MAX_RESOLUTION: u64 = 1000 * 1000;
// this is the maximum width/height that can be obtained
// This is the maximum width/height that can be obtained
const MAX_DIMENSION: u64 = (MAX_RESOLUTION / 64) & !63;
// round the resolution to the nearest multiple of 64
let size = (DVec2::new(width, height).round().clamp(DVec2::ZERO, DVec2::splat(MAX_DIMENSION as _)).as_u64vec2() + U64Vec2::splat(32)).max(U64Vec2::splat(64)) & !U64Vec2::splat(63);
// Round the resolution to the nearest multiple of 64
let size = (size.round().clamp(DVec2::ZERO, DVec2::splat(MAX_DIMENSION as _)).as_u64vec2() + U64Vec2::splat(32)).max(U64Vec2::splat(64)) & !U64Vec2::splat(63);
let resolution = size.x * size.y;
if resolution > MAX_RESOLUTION {
// scale down the image, so it is smaller than MAX_RESOLUTION
// Scale down the image, so it is smaller than MAX_RESOLUTION
let scale = (MAX_RESOLUTION as f64 / resolution as f64).sqrt();
let size = size.as_dvec2() * scale;
if size.x < 64.0 {
// the image is extremely wide
// The image is extremely wide
(64, MAX_DIMENSION)
} else if size.y < 64.0 {
// the image is extremely high
// The image is extremely high
(MAX_DIMENSION, 64)
} else {
// round down to a multiple of 64, so that the resolution still is smaller than MAX_RESOLUTION
// Round down to a multiple of 64, so that the resolution still is smaller than MAX_RESOLUTION
(size.as_u64vec2() & !U64Vec2::splat(63)).into()
}
} else {

View file

@ -474,28 +474,32 @@ fn empty_image<_P: Pixel>(transform: DAffine2, color: _P) -> ImageFrame<_P> {
#[cfg(feature = "serde")]
macro_rules! generate_imaginate_node {
($($val:ident: $t:ident: $o:ty,)*) => {
pub struct ImaginateNode<P: Pixel, E, C, $($t,)*> {
pub struct ImaginateNode<P: Pixel, E, C, G, $($t,)*> {
editor_api: E,
controller: C,
generation_id: G,
$($val: $t,)*
cache: std::sync::Arc<std::sync::Mutex<HashMap<u64, Image<P>>>>,
last_generation: std::sync::atomic::AtomicU64,
}
impl<'e, P: Pixel, E, C, $($t,)*> ImaginateNode<P, E, C, $($t,)*>
impl<'e, P: Pixel, E, C, G, $($t,)*> ImaginateNode<P, E, C, G, $($t,)*>
where $($t: for<'any_input> Node<'any_input, (), Output = DynFuture<'any_input, $o>>,)*
E: for<'any_input> Node<'any_input, (), Output = DynFuture<'any_input, &'e WasmEditorApi>>,
C: for<'any_input> Node<'any_input, (), Output = DynFuture<'any_input, ImaginateController>>,
G: for<'any_input> Node<'any_input, (), Output = DynFuture<'any_input, u64>>,
{
#[allow(clippy::too_many_arguments)]
pub fn new(editor_api: E, controller: C, $($val: $t,)* ) -> Self {
Self { editor_api, controller, $($val,)* cache: Default::default() }
pub fn new(editor_api: E, controller: C, $($val: $t,)* generation_id: G ) -> Self {
Self { editor_api, controller, generation_id, $($val,)* cache: Default::default(), last_generation: std::sync::atomic::AtomicU64::new(u64::MAX) }
}
}
impl<'i, 'e: 'i, P: Pixel + 'i + Hash + Default + Send, E: 'i, C: 'i, $($t: 'i,)*> Node<'i, ImageFrame<P>> for ImaginateNode<P, E, C, $($t,)*>
impl<'i, 'e: 'i, P: Pixel + 'i + Hash + Default + Send, E: 'i, C: 'i, G: 'i, $($t: 'i,)*> Node<'i, ImageFrame<P>> for ImaginateNode<P, E, C, G, $($t,)*>
where $($t: for<'any_input> Node<'any_input, (), Output = DynFuture<'any_input, $o>>,)*
E: for<'any_input> Node<'any_input, (), Output = DynFuture<'any_input, &'e WasmEditorApi>>,
C: for<'any_input> Node<'any_input, (), Output = DynFuture<'any_input, ImaginateController>>,
G: for<'any_input> Node<'any_input, (), Output = DynFuture<'any_input, u64>>,
{
type Output = DynFuture<'i, ImageFrame<P>>;
@ -509,26 +513,45 @@ macro_rules! generate_imaginate_node {
let hash = hasher.finish();
let editor_api = self.editor_api.eval(());
let cache = self.cache.clone();
let generation_future = self.generation_id.eval(());
let last_generation = &self.last_generation;
Box::pin(async move {
// let controller: std::pin::Pin<Box<dyn std::future::Future<Output = ImaginateController> + Send>> = controller;
let controller: ImaginateController = controller.await;
if controller.take_regenerate_trigger() {
let generation_id = generation_future.await;
if generation_id != last_generation.swap(generation_id, std::sync::atomic::Ordering::SeqCst) {
let image = super::imaginate::imaginate(frame.image, editor_api, controller, $($val,)*).await;
cache.lock().unwrap().insert(hash, image.clone());
return ImageFrame { image, ..frame }
return wrap_image_frame(image, frame.transform);
}
let image = cache.lock().unwrap().get(&hash).cloned().unwrap_or_default();
ImageFrame { image, ..frame }
return wrap_image_frame(image, frame.transform);
})
}
}
}
}
fn wrap_image_frame<P: Pixel>(image: Image<P>, transform: DAffine2) -> ImageFrame<P> {
if !transform.decompose_scale().abs_diff_eq(DVec2::ZERO, 0.00001) {
ImageFrame {
image,
transform,
alpha_blending: AlphaBlending::default(),
}
} else {
let resolution = DVec2::new(image.height as f64, image.width as f64);
ImageFrame {
image,
transform: DAffine2::from_scale_angle_translation(resolution, 0., transform.translation),
alpha_blending: AlphaBlending::default(),
}
}
}
#[cfg(feature = "serde")]
generate_imaginate_node! {
seed: Seed: f64,

View file

@ -568,14 +568,14 @@ fn node_registry() -> HashMap<ProtoNodeIdentifier, HashMap<NodeIOTypes, NodeCons
raster_node!(graphene_core::raster::PosterizeNode<_>, params: [f64]),
raster_node!(graphene_core::raster::ExposureNode<_, _, _>, params: [f64, f64, f64]),
vec![(
ProtoNodeIdentifier::new("graphene_std::raster::ImaginateNode<_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _>"),
ProtoNodeIdentifier::new("graphene_std::raster::ImaginateNode<_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _>"),
|args: Vec<graph_craft::proto::SharedNodeContainer>| {
Box::pin(async move {
use graphene_std::raster::ImaginateNode;
macro_rules! instantiate_imaginate_node {
($($i:expr,)*) => { ImaginateNode::new($(graphene_std::any::input_node(args[$i].clone()),)* ) };
}
let node: ImaginateNode<Color, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _> = instantiate_imaginate_node!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,);
let node: ImaginateNode<Color, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _> = instantiate_imaginate_node!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,);
let any = graphene_std::any::DynAnyNode::new(node);
any.into_type_erased()
})
@ -584,9 +584,9 @@ fn node_registry() -> HashMap<ProtoNodeIdentifier, HashMap<NodeIOTypes, NodeCons
concrete!(ImageFrame<Color>),
concrete!(ImageFrame<Color>),
vec![
fn_type!(WasmEditorApi),
fn_type!(&WasmEditorApi),
fn_type!(ImaginateController),
fn_type!(u64),
fn_type!(f64),
fn_type!(Option<DVec2>),
fn_type!(u32),
fn_type!(ImaginateSamplingMethod),
@ -600,6 +600,7 @@ fn node_registry() -> HashMap<ProtoNodeIdentifier, HashMap<NodeIOTypes, NodeCons
fn_type!(ImaginateMaskStartingFill),
fn_type!(bool),
fn_type!(bool),
fn_type!(u64),
],
),
)],