Fix Imaginate by porting its JS roundtrip code to graph-based async execution in Rust (#1250)

* Create asynchronous rust imaginate node

* Make a first imaginate request via rust

* Implement parsing of imaginate API result image

* Stop refresh timer from affecting imaginate progress requests

* Add cargo-about clarification for rustls-webpki

* Delete imaginate.ts and all uses of its functions

* Add imaginate img2img feature

* Fix imaginate random seed button

* Fix imaginate ui inferring non-custom resolutions

* Fix the imaginate progress indicator

* Remove ImaginatePreferences from being compiled into node graph

* Regenerate imaginate only when hitting button

* Add ability to terminate imaginate requests

* Add imaginate server check feature

* Do not compile wasm_bindgen bindings in graphite_editor for tests

* Address some review suggestions

- move wasm futures dependency in editor to the future-executor crate
- guard wasm-bindgen in editor behind a `wasm` feature flag
- dont make seed number input a slider
- remove poll_server_check from process_message function beginning
- guard wasm related code behind `cfg(target_arch = "wasm32")` instead
  of `cfg(test)`
- Call the imaginate idle states "Ready" and "Done" instead of "Nothing
  to do"
- Call the imaginate uploading state "Uploading Image" instead of
  "Uploading Input Image"
- Remove the EvalSyncNode

* Fix imaginate host name being restored between graphite instances

also change the progress status texts a bit.

---------

Co-authored-by: Keavon Chambers <keavon@keavon.com>
This commit is contained in:
nat-rix 2023-06-09 09:03:15 +02:00 committed by Keavon Chambers
parent a1c70c4d90
commit f76b850b9c
35 changed files with 1500 additions and 1326 deletions

710
Cargo.lock generated

File diff suppressed because it is too large Load diff

View file

@ -21,7 +21,12 @@ workarounds = ["ring"]
# is the ISC license but test code within the repo is BSD-3-Clause, but is not compiled into the crate when we use it
[webpki.clarify]
license = "ISC"
[[webpki.clarify.files]]
path = 'LICENSE'
path = "LICENSE"
checksum = "5b698ca13897be3afdb7174256fa1574f8c6892b8bea1a66dd6469d3fe27885a"
[rustls-webpki.clarify]
license = "ISC"
[[rustls-webpki.clarify.files]]
path = "LICENSE"
checksum = "5b698ca13897be3afdb7174256fa1574f8c6892b8bea1a66dd6469d3fe27885a"

View file

@ -983,23 +983,5 @@ pub fn pick_layer_safe_imaginate_resolution(layer: &Layer, render_data: &RenderD
let layer_bounds = layer.bounding_transform(render_data);
let layer_bounds_size = (layer_bounds.transform_vector2((1., 0.).into()).length(), layer_bounds.transform_vector2((0., 1.).into()).length());
pick_safe_imaginate_resolution(layer_bounds_size)
}
pub fn pick_safe_imaginate_resolution((width, height): (f64, f64)) -> (u64, u64) {
const MAX_RESOLUTION: u64 = 1000 * 1000;
let mut scale_factor = 1.;
let round_to_increment = |size: f64| (size / 64.).round() as u64 * 64;
loop {
let possible_solution = (round_to_increment(width * scale_factor), round_to_increment(height * scale_factor));
if possible_solution.0 * possible_solution.1 <= MAX_RESOLUTION {
return possible_solution;
}
scale_factor -= 0.1;
}
graphene_std::imaginate::pick_safe_imaginate_resolution(layer_bounds_size)
}

View file

@ -11,11 +11,13 @@ repository = "https://github.com/GraphiteEditor/Graphite"
license = "Apache-2.0"
[features]
default = ["wasm"]
gpu = ["interpreted-executor/gpu", "graphene-std/gpu", "graphene-core/gpu", "wgpu-executor", "gpu-executor"]
quantization = [
"graphene-std/quantization",
"interpreted-executor/quantization",
]
wasm = ["wasm-bindgen", "future-executor", "graphene-std/wasm"]
[dependencies]
log = "0.4"
@ -45,9 +47,12 @@ gpu-executor = { path = "../node-graph/gpu-executor", optional = true }
interpreted-executor = { path = "../node-graph/interpreted-executor" }
dyn-any = { path = "../libraries/dyn-any" }
graphene-core = { path = "../node-graph/gcore" }
graphene-std = { path = "../node-graph/gstd", features = ["wasm"] }
graphene-std = { path = "../node-graph/gstd" }
future-executor = { path = "../node-graph/future-executor", optional = true }
num_enum = "0.6.1"
wasm-bindgen = { version = "0.2.86", optional = true }
[dependencies.document-legacy]
path = "../document-legacy"
package = "graphite-document-legacy"

View file

@ -9,7 +9,6 @@ use crate::messages::tool::utility_types::HintData;
use document_legacy::LayerId;
use graph_craft::document::NodeId;
use graph_craft::imaginate_input::*;
use graphene_core::raster::color::Color;
use graphene_core::text::Font;
@ -75,40 +74,6 @@ pub enum FrontendMessage {
#[serde(rename = "isDefault")]
is_default: bool,
},
TriggerImaginateCheckServerStatus {
hostname: String,
},
TriggerImaginateGenerate {
parameters: Box<ImaginateGenerationParameters>,
#[serde(rename = "baseImage")]
base_image: Option<Box<ImaginateBaseImage>>,
#[serde(rename = "maskImage")]
mask_image: Option<Box<ImaginateMaskImage>>,
#[serde(rename = "maskPaintMode")]
mask_paint_mode: ImaginateMaskPaintMode,
#[serde(rename = "maskBlurPx")]
mask_blur_px: u32,
#[serde(rename = "maskFillContent")]
imaginate_mask_starting_fill: ImaginateMaskStartingFill,
hostname: String,
#[serde(rename = "refreshFrequency")]
refresh_frequency: f64,
#[serde(rename = "documentId")]
document_id: u64,
#[serde(rename = "layerPath")]
layer_path: Vec<LayerId>,
#[serde(rename = "nodePath")]
node_path: Vec<NodeId>,
},
TriggerImaginateTerminate {
#[serde(rename = "documentId")]
document_id: u64,
#[serde(rename = "layerPath")]
layer_path: Vec<LayerId>,
#[serde(rename = "nodePath")]
node_path: Vec<NodeId>,
hostname: String,
},
TriggerImport,
TriggerIndexedDbRemoveDocument {
#[serde(rename = "documentId")]

View file

@ -93,8 +93,6 @@ pub enum DocumentMessage {
GroupSelectedLayers,
ImaginateClear {
layer_path: Vec<LayerId>,
node_id: NodeId,
cached_index: usize,
},
ImaginateGenerate {
layer_path: Vec<LayerId>,
@ -105,10 +103,6 @@ pub enum DocumentMessage {
imaginate_node: Vec<NodeId>,
then_generate: bool,
},
ImaginateTerminate {
layer_path: Vec<LayerId>,
node_path: Vec<NodeId>,
},
InputFrameRasterizeRegionBelowLayer {
layer_path: Vec<LayerId>,
},

View file

@ -465,15 +465,7 @@ impl MessageHandler<DocumentMessage, (u64, &InputPreprocessorMessageHandler, &Pe
replacement_selected_layers: vec![new_folder_path],
});
}
ImaginateClear {
layer_path,
node_id,
cached_index: input_index,
} => {
let value = graph_craft::document::value::TaggedValue::RcImage(None);
responses.add(NodeGraphMessage::SetInputValue { node_id, input_index, value });
responses.add(InputFrameRasterizeRegionBelowLayer { layer_path });
}
ImaginateClear { layer_path } => responses.add(InputFrameRasterizeRegionBelowLayer { layer_path }),
ImaginateGenerate { layer_path, imaginate_node } => {
if let Some(message) = self.rasterize_region_below_layer(document_id, layer_path, preferences, persistent_data, Some(imaginate_node)) {
responses.add(message);
@ -484,12 +476,16 @@ impl MessageHandler<DocumentMessage, (u64, &InputPreprocessorMessageHandler, &Pe
imaginate_node,
then_generate,
} => {
// Generate a random seed. We only want values between -2^53 and 2^53, because integer values
// outside of this range can get rounded in f64
let random_bits = generate_uuid();
let random_value = ((random_bits >> 11) as f64).copysign(f64::from_bits(random_bits & (1 << 63)));
// Set a random seed input
responses.add(NodeGraphMessage::SetInputValue {
node_id: *imaginate_node.last().unwrap(),
// Needs to match the index of the seed parameter in `pub const IMAGINATE_NODE: DocumentNodeType` in `document_node_type.rs`
input_index: 1,
value: graph_craft::document::value::TaggedValue::F64((generate_uuid() >> 1) as f64),
input_index: 3,
value: graph_craft::document::value::TaggedValue::F64(random_value),
});
// Generate the image
@ -497,14 +493,6 @@ impl MessageHandler<DocumentMessage, (u64, &InputPreprocessorMessageHandler, &Pe
responses.add(DocumentMessage::ImaginateGenerate { layer_path, imaginate_node });
}
}
ImaginateTerminate { layer_path, node_path } => {
responses.add(FrontendMessage::TriggerImaginateTerminate {
document_id,
layer_path,
node_path,
hostname: preferences.imaginate_server_hostname.clone(),
});
}
InputFrameRasterizeRegionBelowLayer { layer_path } => {
if layer_path.is_empty() {
responses.add(NodeGraphMessage::RunDocumentGraph);
@ -996,7 +984,7 @@ impl DocumentMessageHandler {
// Check if we use the "Input Frame" node.
// TODO: Remove once rasterization is moved into a node.
let input_frame_node_id = node_network.nodes.iter().find(|(_, node)| node.name == "Input Frame").map(|(&id, _)| id);
let input_frame_connected_to_graph_output = input_frame_node_id.map_or(false, |target_node_id| node_network.connected_to_output(target_node_id, imaginate_node_path.is_none()));
let input_frame_connected_to_graph_output = input_frame_node_id.map_or(false, |target_node_id| node_network.connected_to_output(target_node_id));
// If the Input Frame node is connected upstream, rasterize the artwork below this layer by calling into JS
let response = if input_frame_connected_to_graph_output {

View file

@ -459,7 +459,7 @@ impl MessageHandler<NodeGraphMessage, (&mut Document, &NodeGraphExecutor, u64)>
let input = NodeInput::node(output_node, output_node_connector_index);
responses.add(NodeGraphMessage::SetNodeInput { node_id, input_index, input });
let should_rerender = network.connected_to_output(node_id, true);
let should_rerender = network.connected_to_output(node_id);
responses.add(NodeGraphMessage::SendGraph { should_rerender });
}
NodeGraphMessage::Copy => {
@ -517,7 +517,7 @@ impl MessageHandler<NodeGraphMessage, (&mut Document, &NodeGraphExecutor, u64)>
if let Some(network) = self.get_active_network(document) {
// Only generate node graph if one of the selected nodes is connected to the output
if self.selected_nodes.iter().any(|&node_id| network.connected_to_output(node_id, true)) {
if self.selected_nodes.iter().any(|&node_id| network.connected_to_output(node_id)) {
if let Some(layer_path) = self.layer_path.clone() {
responses.add(DocumentMessage::InputFrameRasterizeRegionBelowLayer { layer_path });
}
@ -549,7 +549,7 @@ impl MessageHandler<NodeGraphMessage, (&mut Document, &NodeGraphExecutor, u64)>
}
responses.add(NodeGraphMessage::SetNodeInput { node_id, input_index, input });
let should_rerender = network.connected_to_output(node_id, true);
let should_rerender = network.connected_to_output(node_id);
responses.add(NodeGraphMessage::SendGraph { should_rerender });
}
NodeGraphMessage::DoubleClickNode { node } => {
@ -626,7 +626,7 @@ impl MessageHandler<NodeGraphMessage, (&mut Document, &NodeGraphExecutor, u64)>
}
responses.add(NodeGraphMessage::SetNodeInput { node_id, input_index, input });
let should_rerender = network.connected_to_output(node_id, true);
let should_rerender = network.connected_to_output(node_id);
responses.add(NodeGraphMessage::SendGraph { should_rerender });
responses.add(PropertiesPanelMessage::ResendActiveProperties);
}
@ -743,7 +743,7 @@ impl MessageHandler<NodeGraphMessage, (&mut Document, &NodeGraphExecutor, u64)>
let input = NodeInput::Value { tagged_value: value, exposed: false };
responses.add(NodeGraphMessage::SetNodeInput { node_id, input_index, input });
responses.add(PropertiesPanelMessage::ResendActiveProperties);
if (node.name != "Imaginate" || input_index == 0) && network.connected_to_output(node_id, true) {
if (node.name != "Imaginate" || input_index == 0) && network.connected_to_output(node_id) {
if let Some(layer_path) = self.layer_path.clone() {
responses.add(DocumentMessage::InputFrameRasterizeRegionBelowLayer { layer_path });
} else {
@ -780,7 +780,7 @@ impl MessageHandler<NodeGraphMessage, (&mut Document, &NodeGraphExecutor, u64)>
node.inputs.extend(((node.inputs.len() - 1)..input_index).map(|_| NodeInput::Network(generic!(T))));
}
node.inputs[input_index] = NodeInput::Value { tagged_value: value, exposed: false };
if network.connected_to_output(*node_id, true) {
if network.connected_to_output(*node_id) {
responses.add(DocumentMessage::InputFrameRasterizeRegionBelowLayer { layer_path });
}
}
@ -854,7 +854,7 @@ impl MessageHandler<NodeGraphMessage, (&mut Document, &NodeGraphExecutor, u64)>
Self::send_graph(network, executor, &self.layer_path, responses);
// Only generate node graph if one of the selected nodes is connected to the output
if self.selected_nodes.iter().any(|&node_id| network.connected_to_output(node_id, true)) {
if self.selected_nodes.iter().any(|&node_id| network.connected_to_output(node_id)) {
if let Some(layer_path) = self.layer_path.clone() {
responses.add(DocumentMessage::InputFrameRasterizeRegionBelowLayer { layer_path });
}

View file

@ -1673,12 +1673,63 @@ fn static_nodes() -> Vec<DocumentNodeType> {
pub static IMAGINATE_NODE: Lazy<DocumentNodeType> = Lazy::new(|| DocumentNodeType {
name: "Imaginate",
category: "Image Synthesis",
identifier: NodeImplementation::proto("graphene_std::raster::ImaginateNode<_>"),
identifier: NodeImplementation::DocumentNode(NodeNetwork {
inputs: vec![0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
outputs: vec![NodeOutput::new(1, 0)],
nodes: [
(
0,
DocumentNode {
name: "Frame Monitor".into(),
inputs: vec![NodeInput::Network(concrete!(ImageFrame<Color>))],
implementation: DocumentNodeImplementation::proto("graphene_core::memo::MonitorNode<_>"),
..Default::default()
},
),
(
1,
DocumentNode {
name: "Imaginate".into(),
inputs: vec![
NodeInput::node(0, 0),
NodeInput::Network(concrete!(WasmEditorApi)),
NodeInput::Network(concrete!(ImaginateController)),
NodeInput::Network(concrete!(f64)),
NodeInput::Network(concrete!(Option<DVec2>)),
NodeInput::Network(concrete!(u32)),
NodeInput::Network(concrete!(ImaginateSamplingMethod)),
NodeInput::Network(concrete!(f64)),
NodeInput::Network(concrete!(String)),
NodeInput::Network(concrete!(String)),
NodeInput::Network(concrete!(bool)),
NodeInput::Network(concrete!(f64)),
NodeInput::Network(concrete!(Option<Vec<u64>>)),
NodeInput::Network(concrete!(bool)),
NodeInput::Network(concrete!(f64)),
NodeInput::Network(concrete!(ImaginateMaskStartingFill)),
NodeInput::Network(concrete!(bool)),
NodeInput::Network(concrete!(bool)),
NodeInput::Network(concrete!(ImaginateCache)),
],
implementation: DocumentNodeImplementation::proto("graphene_std::raster::ImaginateNode<_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _>"),
..Default::default()
},
),
]
.into(),
..Default::default()
}),
inputs: vec![
DocumentInputType::value("Input Image", TaggedValue::ImageFrame(ImageFrame::empty()), true),
DocumentInputType {
name: "Editor Api",
data_type: FrontendGraphDataType::General,
default: NodeInput::Network(concrete!(WasmEditorApi)),
},
DocumentInputType::value("Controller", TaggedValue::ImaginateController(Default::default()), false),
DocumentInputType::value("Seed", TaggedValue::F64(0.), false), // Remember to keep index used in `ImaginateRandom` updated with this entry's index
DocumentInputType::value("Resolution", TaggedValue::OptionalDVec2(None), false),
DocumentInputType::value("Samples", TaggedValue::F64(30.), false),
DocumentInputType::value("Samples", TaggedValue::U32(30), false),
DocumentInputType::value("Sampling Method", TaggedValue::ImaginateSamplingMethod(ImaginateSamplingMethod::EulerA), false),
DocumentInputType::value("Prompt Guidance", TaggedValue::F64(7.5), false),
DocumentInputType::value("Prompt", TaggedValue::String(String::new()), false),
@ -1691,10 +1742,7 @@ pub static IMAGINATE_NODE: Lazy<DocumentNodeType> = Lazy::new(|| DocumentNodeTyp
DocumentInputType::value("Mask Starting Fill", TaggedValue::ImaginateMaskStartingFill(ImaginateMaskStartingFill::Fill), false),
DocumentInputType::value("Improve Faces", TaggedValue::Bool(false), false),
DocumentInputType::value("Tiling", TaggedValue::Bool(false), false),
// Non-user status (is document input the right way to do this?)
DocumentInputType::value("Cached Data", TaggedValue::RcImage(None), false),
DocumentInputType::value("Percent Complete", TaggedValue::F64(0.), false),
DocumentInputType::value("Status", TaggedValue::ImaginateStatus(ImaginateStatus::Idle), false),
DocumentInputType::value("Cache", TaggedValue::ImaginateCache(Default::default()), false),
],
outputs: vec![DocumentOutputType::new("Image", FrontendGraphDataType::Raster)],
properties: node_properties::imaginate_properties,

View file

@ -5,9 +5,11 @@ use super::FrontendGraphDataType;
use crate::messages::layout::utility_types::widget_prelude::*;
use crate::messages::prelude::*;
use document_legacy::{layers::layer_info::LayerDataTypeDiscriminant, Operation};
use graph_craft::concrete;
use graph_craft::document::value::TaggedValue;
use graph_craft::document::{DocumentNode, NodeId, NodeInput};
use graph_craft::imaginate_input::{ImaginateMaskStartingFill, ImaginateSamplingMethod, ImaginateServerStatus, ImaginateStatus};
use graphene_core::raster::{BlendMode, Color, ImageFrame, LuminanceCalculation, RedGreenBlue, RelativeAbsolute, SelectiveColorChoice};
use graphene_core::text::Font;
use graphene_core::vector::style::{FillType, GradientType, LineCap, LineJoin};
@ -980,12 +982,17 @@ pub fn node_section_font(document_node: &DocumentNode, node_id: NodeId, _context
result
}
pub fn imaginate_properties(_document_node: &DocumentNode, _node_id: NodeId, _context: &mut NodePropertiesContext) -> Vec<LayoutGroup> {
/*
pub fn imaginate_properties(document_node: &DocumentNode, node_id: NodeId, context: &mut NodePropertiesContext) -> Vec<LayoutGroup> {
let imaginate_node = [context.nested_path, &[node_id]].concat();
let layer_path = context.layer_path.to_vec();
let resolve_input = |name: &str| IMAGINATE_NODE.inputs.iter().position(|input| input.name == name).unwrap_or_else(|| panic!("Input {name} not found"));
let resolve_input = |name: &str| {
super::IMAGINATE_NODE
.inputs
.iter()
.position(|input| input.name == name)
.unwrap_or_else(|| panic!("Input {name} not found"))
};
let seed_index = resolve_input("Seed");
let resolution_index = resolve_input("Resolution");
let samples_index = resolve_input("Samples");
@ -1001,22 +1008,12 @@ pub fn imaginate_properties(_document_node: &DocumentNode, _node_id: NodeId, _co
let mask_fill_index = resolve_input("Mask Starting Fill");
let faces_index = resolve_input("Improve Faces");
let tiling_index = resolve_input("Tiling");
let cached_index = resolve_input("Cached Data");
let cached_value = &document_node.inputs[cached_index];
let complete_value = &document_node.inputs[resolve_input("Percent Complete")];
let status_value = &document_node.inputs[resolve_input("Status")];
let controller = &document_node.inputs[resolve_input("Controller")];
let server_status = {
let status = match &context.persistent_data.imaginate_server_status {
ImaginateServerStatus::Unknown => {
context.responses.add(PortfolioMessage::ImaginateCheckServerStatus);
"Checking..."
}
ImaginateServerStatus::Checking => "Checking...",
ImaginateServerStatus::Unavailable => "Unavailable",
ImaginateServerStatus::Connected => "Connected",
};
let server_status = context.persistent_data.imaginate.server_status();
let status_text = server_status.to_text();
let mut widgets = vec![
WidgetHolder::text_widget("Server"),
WidgetHolder::unrelated_separator(),
@ -1025,14 +1022,14 @@ pub fn imaginate_properties(_document_node: &DocumentNode, _node_id: NodeId, _co
.on_update(|_| DialogMessage::RequestPreferencesDialog.into())
.widget_holder(),
WidgetHolder::unrelated_separator(),
WidgetHolder::bold_text(status),
WidgetHolder::bold_text(status_text),
WidgetHolder::related_separator(),
IconButton::new("Reload", 24)
.tooltip("Refresh connection status")
.on_update(|_| PortfolioMessage::ImaginateCheckServerStatus.into())
.widget_holder(),
];
if context.persistent_data.imaginate_server_status == ImaginateServerStatus::Unavailable {
if let ImaginateServerStatus::Unavailable | ImaginateServerStatus::Failed(_) = server_status {
widgets.extend([
WidgetHolder::unrelated_separator(),
TextButton::new("Server Help")
@ -1049,15 +1046,11 @@ pub fn imaginate_properties(_document_node: &DocumentNode, _node_id: NodeId, _co
LayoutGroup::Row { widgets }.with_tooltip("Connection status to the server that computes generated images")
};
let &NodeInput::Value {tagged_value: TaggedValue::ImaginateStatus( imaginate_status),..} = status_value else {
panic!("Invalid status input")
};
let NodeInput::Value {tagged_value: TaggedValue::RcImage( cached_data),..} = cached_value else {
panic!("Invalid cached image input, received {:?}, index: {}", cached_value, cached_index)
};
let &NodeInput::Value {tagged_value: TaggedValue::F64( percent_complete),..} = complete_value else {
panic!("Invalid percent complete input")
let &NodeInput::Value {tagged_value: TaggedValue::ImaginateController(ref controller),..} = controller else {
panic!("Invalid output status input")
};
let imaginate_status = controller.get_status();
let use_base_image = if let &NodeInput::Value {
tagged_value: TaggedValue::Bool(use_base_image),
..
@ -1071,23 +1064,7 @@ pub fn imaginate_properties(_document_node: &DocumentNode, _node_id: NodeId, _co
let transform_not_connected = false;
let progress = {
// Since we don't serialize the status, we need to derive from other state whether the Idle state is actually supposed to be the Terminated state
let mut interpreted_status = imaginate_status;
if imaginate_status == ImaginateStatus::Idle && cached_data.is_some() && percent_complete > 0. && percent_complete < 100. {
interpreted_status = ImaginateStatus::Terminated;
}
let status = match interpreted_status {
ImaginateStatus::Idle => match cached_data {
Some(_) => "Done".into(),
None => "Ready".into(),
},
ImaginateStatus::Beginning => "Beginning...".into(),
ImaginateStatus::Uploading(percent) => format!("Uploading Input Image: {percent:.0}%"),
ImaginateStatus::Generating => format!("Generating: {percent_complete:.0}%"),
ImaginateStatus::Terminating => "Terminating...".into(),
ImaginateStatus::Terminated => format!("{percent_complete:.0}% (Terminated)"),
};
let status = imaginate_status.to_text();
let widgets = vec![
WidgetHolder::text_widget("Progress"),
WidgetHolder::unrelated_separator(),
@ -1095,38 +1072,38 @@ pub fn imaginate_properties(_document_node: &DocumentNode, _node_id: NodeId, _co
WidgetHolder::unrelated_separator(), // TODO: which is the width of the Assist area.
WidgetHolder::unrelated_separator(), // TODO: Remove these when we have proper entry row formatting that includes room for Assists.
WidgetHolder::unrelated_separator(),
WidgetHolder::bold_text(status),
WidgetHolder::bold_text(status.as_ref()),
];
LayoutGroup::Row { widgets }.with_tooltip("When generating, the percentage represents how many sampling steps have so far been processed out of the target number")
LayoutGroup::Row { widgets }.with_tooltip(match imaginate_status {
ImaginateStatus::Failed(_) => status.as_ref(),
_ => "When generating, the percentage represents how many sampling steps have so far been processed out of the target number",
})
};
let image_controls = {
let image_controls: _ = {
let mut widgets = vec![WidgetHolder::text_widget("Image"), WidgetHolder::unrelated_separator()];
let assist_separators = vec![
let assist_separators = [
WidgetHolder::unrelated_separator(), // TODO: These three separators add up to 24px,
WidgetHolder::unrelated_separator(), // TODO: which is the width of the Assist area.
WidgetHolder::unrelated_separator(), // TODO: Remove these when we have proper entry row formatting that includes room for Assists.
WidgetHolder::unrelated_separator(),
];
match imaginate_status {
ImaginateStatus::Beginning | ImaginateStatus::Uploading(_) => {
match &imaginate_status {
ImaginateStatus::Beginning | ImaginateStatus::Uploading => {
widgets.extend_from_slice(&assist_separators);
widgets.push(TextButton::new("Beginning...").tooltip("Sending image generation request to the server").disabled(true).widget_holder());
}
ImaginateStatus::Generating => {
ImaginateStatus::Generating(_) => {
widgets.extend_from_slice(&assist_separators);
widgets.push(
TextButton::new("Terminate")
.tooltip("Cancel the in-progress image generation and keep the latest progress")
.on_update({
let imaginate_node = imaginate_node.clone();
let controller = controller.clone();
move |_| {
DocumentMessage::ImaginateTerminate {
layer_path: layer_path.clone(),
node_path: imaginate_node.clone(),
}
.into()
controller.request_termination();
Message::NoOp
}
})
.widget_holder(),
@ -1141,13 +1118,15 @@ pub fn imaginate_properties(_document_node: &DocumentNode, _node_id: NodeId, _co
.widget_holder(),
);
}
ImaginateStatus::Idle | ImaginateStatus::Terminated => widgets.extend_from_slice(&[
ImaginateStatus::Ready | ImaginateStatus::ReadyDone | ImaginateStatus::Terminated | ImaginateStatus::Failed(_) => widgets.extend_from_slice(&[
IconButton::new("Random", 24)
.tooltip("Generate with a new random seed")
.on_update({
let imaginate_node = imaginate_node.clone();
let layer_path = context.layer_path.to_vec();
let controller = controller.clone();
move |_| {
controller.trigger_regenerate();
DocumentMessage::ImaginateRandom {
layer_path: layer_path.clone(),
imaginate_node: imaginate_node.clone(),
@ -1163,7 +1142,9 @@ pub fn imaginate_properties(_document_node: &DocumentNode, _node_id: NodeId, _co
.on_update({
let imaginate_node = imaginate_node.clone();
let layer_path = context.layer_path.to_vec();
let controller = controller.clone();
move |_| {
controller.trigger_regenerate();
DocumentMessage::ImaginateGenerate {
layer_path: layer_path.clone(),
imaginate_node: imaginate_node.clone(),
@ -1175,16 +1156,13 @@ pub fn imaginate_properties(_document_node: &DocumentNode, _node_id: NodeId, _co
WidgetHolder::related_separator(),
TextButton::new("Clear")
.tooltip("Remove generated image from the layer frame")
.disabled(cached_data.is_none())
.disabled(!matches!(imaginate_status, ImaginateStatus::ReadyDone))
.on_update({
let layer_path = context.layer_path.to_vec();
let controller = controller.clone();
move |_| {
DocumentMessage::ImaginateClear {
node_id,
layer_path: layer_path.clone(),
cached_index,
}
.into()
controller.set_status(ImaginateStatus::Ready);
DocumentMessage::ImaginateClear { layer_path: layer_path.clone() }.into()
}
})
.widget_holder(),
@ -1221,9 +1199,11 @@ pub fn imaginate_properties(_document_node: &DocumentNode, _node_id: NodeId, _co
.widget_holder(),
WidgetHolder::unrelated_separator(),
NumberInput::new(Some(seed))
.min(0.)
.int()
.min(-((1u64 << f64::MANTISSA_DIGITS) as f64))
.max((1u64 << f64::MANTISSA_DIGITS) as f64)
.on_update(update_value(move |input: &NumberInput| TaggedValue::F64(input.value.unwrap()), node_id, seed_index))
.mode(NumberInputMode::Increment)
.widget_holder(),
])
}
@ -1231,23 +1211,19 @@ pub fn imaginate_properties(_document_node: &DocumentNode, _node_id: NodeId, _co
LayoutGroup::Row { widgets }.with_tooltip("Seed determines the random outcome, enabling limitless unique variations")
};
// Create the input to the graph using an empty image
let editor_api = std::borrow::Cow::Owned(EditorApi {
image_frame: None,
font_cache: Some(&context.persistent_data.font_cache),
});
// Compute the transform input to the image frame
let image_frame: ImageFrame<Color> = context.executor.compute_input(context.network, &imaginate_node, 0, editor_api).unwrap_or_default();
let transform = image_frame.transform;
let transform = context
.executor
.introspect_node_in_network(context.network, &imaginate_node, |network| network.inputs.first().copied(), |frame: &ImageFrame<Color>| frame.transform)
.unwrap_or_default();
let resolution = {
use document_legacy::document::pick_safe_imaginate_resolution;
use graphene_std::imaginate::pick_safe_imaginate_resolution;
let mut widgets = start_widgets(document_node, node_id, resolution_index, "Resolution", FrontendGraphDataType::Vector, false);
let round = |x: DVec2| {
let (x, y) = pick_safe_imaginate_resolution(x.into());
Some(DVec2::new(x as f64, y as f64))
DVec2::new(x as f64, y as f64)
};
if let &NodeInput::Value {
@ -1256,14 +1232,7 @@ pub fn imaginate_properties(_document_node: &DocumentNode, _node_id: NodeId, _co
} = &document_node.inputs[resolution_index]
{
let dimensions_is_auto = vec2.is_none();
let vec2 = vec2.unwrap_or_else(|| {
let w = transform.transform_vector2(DVec2::new(1., 0.)).length();
let h = transform.transform_vector2(DVec2::new(0., 1.)).length();
let (x, y) = pick_safe_imaginate_resolution((w, h));
DVec2::new(x as f64, y as f64)
});
let vec2 = vec2.unwrap_or_else(|| round([transform.matrix2.x_axis, transform.matrix2.y_axis].map(DVec2::length).into()));
let layer_path = context.layer_path.to_vec();
widgets.extend_from_slice(&[
@ -1308,7 +1277,7 @@ pub fn imaginate_properties(_document_node: &DocumentNode, _node_id: NodeId, _co
.unit(" px")
.disabled(dimensions_is_auto && !transform_not_connected)
.on_update(update_value(
move |number_input: &NumberInput| TaggedValue::OptionalDVec2(round(DVec2::new(number_input.value.unwrap(), vec2.y))),
move |number_input: &NumberInput| TaggedValue::OptionalDVec2(Some(round(DVec2::new(number_input.value.unwrap(), vec2.y)))),
node_id,
resolution_index,
))
@ -1321,7 +1290,7 @@ pub fn imaginate_properties(_document_node: &DocumentNode, _node_id: NodeId, _co
.unit(" px")
.disabled(dimensions_is_auto && !transform_not_connected)
.on_update(update_value(
move |number_input: &NumberInput| TaggedValue::OptionalDVec2(round(DVec2::new(vec2.x, number_input.value.unwrap()))),
move |number_input: &NumberInput| TaggedValue::OptionalDVec2(Some(round(DVec2::new(vec2.x, number_input.value.unwrap())))),
node_id,
resolution_index,
))
@ -1538,8 +1507,6 @@ pub fn imaginate_properties(_document_node: &DocumentNode, _node_id: NodeId, _co
layout.extend_from_slice(&[improve_faces, tiling]);
layout
*/
todo!()
}
fn unknown_node_properties(document_node: &DocumentNode) -> Vec<LayoutGroup> {

View file

@ -1,10 +1,8 @@
use super::utility_types::ImaginateServerStatus;
use crate::messages::portfolio::document::utility_types::clipboards::Clipboard;
use crate::messages::prelude::*;
use document_legacy::LayerId;
use graph_craft::document::NodeId;
use graph_craft::imaginate_input::ImaginateStatus;
use graphene_core::text::Font;
use serde::{Deserialize, Serialize};
@ -57,24 +55,9 @@ pub enum PortfolioMessage {
is_default: bool,
},
ImaginateCheckServerStatus,
ImaginateSetGeneratingStatus {
document_id: u64,
layer_path: Vec<LayerId>,
node_path: Vec<NodeId>,
percent: Option<f64>,
status: ImaginateStatus,
},
ImaginateSetImageData {
document_id: u64,
layer_path: Vec<LayerId>,
node_path: Vec<NodeId>,
image_data: Vec<u8>,
width: u32,
height: u32,
},
ImaginateSetServerStatus {
status: ImaginateServerStatus,
},
ImaginatePollServerStatus,
ImaginatePreferences,
ImaginateServerHostname,
Import,
LoadDocumentResources {
document_id: u64,

View file

@ -7,9 +7,7 @@ use crate::messages::dialog::simple_dialogs;
use crate::messages::frontend::utility_types::FrontendDocumentDetails;
use crate::messages::layout::utility_types::layout_widget::PropertyHolder;
use crate::messages::layout::utility_types::misc::LayoutTarget;
use crate::messages::portfolio::document::node_graph::IMAGINATE_NODE;
use crate::messages::portfolio::document::utility_types::clipboards::{Clipboard, CopyBufferEntry, INTERNAL_CLIPBOARD_COUNT};
use crate::messages::portfolio::utility_types::ImaginateServerStatus;
use crate::messages::prelude::*;
use crate::messages::tool::utility_types::{HintData, HintGroup};
use crate::node_graph_executor::NodeGraphExecutor;
@ -19,7 +17,6 @@ use document_legacy::layers::style::RenderData;
use document_legacy::Operation as DocumentOperation;
use graph_craft::document::value::TaggedValue;
use graph_craft::document::{NodeId, NodeInput};
use graphene_core::raster::Image;
use graphene_core::text::Font;
#[derive(Debug, Default)]
@ -218,70 +215,35 @@ impl MessageHandler<PortfolioMessage, (&InputPreprocessorMessageHandler, &Prefer
self.executor.update_font_cache(self.persistent_data.font_cache.clone());
}
PortfolioMessage::ImaginateCheckServerStatus => {
self.persistent_data.imaginate_server_status = ImaginateServerStatus::Checking;
responses.add(FrontendMessage::TriggerImaginateCheckServerStatus {
hostname: preferences.imaginate_server_hostname.clone(),
});
responses.add(PropertiesPanelMessage::ResendActiveProperties);
}
PortfolioMessage::ImaginateSetGeneratingStatus {
document_id,
layer_path,
node_path,
percent,
status,
} => {
let get = |name: &str| IMAGINATE_NODE.inputs.iter().position(|input| input.name == name).unwrap_or_else(|| panic!("Input {name} not found"));
if let Some(percentage) = percent {
responses.add(PortfolioMessage::DocumentPassMessage {
document_id,
message: NodeGraphMessage::SetQualifiedInputValue {
layer_path: layer_path.clone(),
node_path: node_path.clone(),
input_index: get("Percent Complete"),
value: TaggedValue::F64(percentage),
let server_status = self.persistent_data.imaginate.server_status().clone();
self.persistent_data.imaginate.poll_server_check();
#[cfg(target_arch = "wasm32")]
if let Some(fut) = self.persistent_data.imaginate.initiate_server_check() {
future_executor::spawn(async move {
let () = fut.await;
use wasm_bindgen::prelude::*;
#[wasm_bindgen(module = "/../frontend/src/wasm-communication/editor.ts")]
extern "C" {
#[wasm_bindgen(js_name = injectImaginatePollServerStatus)]
fn inject();
}
.into(),
});
inject();
})
}
if &server_status != self.persistent_data.imaginate.server_status() {
responses.add(PropertiesPanelMessage::ResendActiveProperties);
}
responses.add(PortfolioMessage::DocumentPassMessage {
document_id,
message: NodeGraphMessage::SetQualifiedInputValue {
layer_path,
node_path,
input_index: get("Status"),
value: TaggedValue::ImaginateStatus(status),
}
.into(),
});
}
PortfolioMessage::ImaginateSetImageData {
document_id,
layer_path,
node_path,
image_data,
width,
height,
} => {
let get = |name: &str| IMAGINATE_NODE.inputs.iter().position(|input| input.name == name).unwrap_or_else(|| panic!("Input {name} not found"));
let image = Image::from_image_data(&image_data, width, height);
responses.add(PortfolioMessage::DocumentPassMessage {
document_id,
message: NodeGraphMessage::SetQualifiedInputValue {
layer_path,
node_path,
input_index: get("Cached Data"),
value: TaggedValue::RcImage(Some(std::sync::Arc::new(image))),
}
.into(),
});
}
PortfolioMessage::ImaginateSetServerStatus { status } => {
self.persistent_data.imaginate_server_status = status;
PortfolioMessage::ImaginatePollServerStatus => {
self.persistent_data.imaginate.poll_server_check();
responses.add(PropertiesPanelMessage::ResendActiveProperties);
}
PortfolioMessage::ImaginatePreferences => self.executor.update_imaginate_preferences(preferences.get_imaginate_preferences()),
PortfolioMessage::ImaginateServerHostname => {
info!("setting imaginate persistent data");
self.persistent_data.imaginate.set_host_name(&preferences.imaginate_server_hostname);
}
PortfolioMessage::Import => {
// This portfolio message wraps the frontend message so it can be listed as an action, which isn't possible for frontend messages
if self.active_document().is_some() {
@ -461,7 +423,6 @@ impl MessageHandler<PortfolioMessage, (&InputPreprocessorMessageHandler, &Prefer
(document_id, &mut self.documents),
layer_path,
(input_image_data, size),
imaginate_node_path,
(preferences, &self.persistent_data),
responses,
);

View file

@ -1,29 +1,11 @@
use graphene_std::text::FontCache;
use graphene_std::{imaginate::ImaginatePersistentData, text::FontCache};
use serde::{Deserialize, Serialize};
#[derive(Clone, Debug)]
#[derive(Debug, Default)]
pub struct PersistentData {
pub font_cache: FontCache,
pub imaginate_server_status: ImaginateServerStatus,
}
impl Default for PersistentData {
fn default() -> Self {
Self {
font_cache: Default::default(),
imaginate_server_status: ImaginateServerStatus::Unknown,
}
}
}
#[derive(PartialEq, Eq, Clone, Copy, Default, Debug, Serialize, Deserialize, specta::Type)]
pub enum ImaginateServerStatus {
#[default]
Unknown,
Checking,
Unavailable,
Connected,
pub imaginate: ImaginatePersistentData,
}
#[derive(PartialEq, Eq, Clone, Copy, Default, Debug, Serialize, Deserialize)]

View file

@ -1,5 +1,6 @@
use crate::messages::input_mapper::key_mapping::MappingVariant;
use crate::messages::prelude::*;
use graph_craft::imaginate_input::ImaginatePreferences;
use serde::{Deserialize, Serialize};
@ -10,10 +11,19 @@ pub struct PreferencesMessageHandler {
pub zoom_with_scroll: bool,
}
impl PreferencesMessageHandler {
pub fn get_imaginate_preferences(&self) -> ImaginatePreferences {
ImaginatePreferences {
host_name: self.imaginate_server_hostname.clone(),
}
}
}
impl Default for PreferencesMessageHandler {
fn default() -> Self {
let ImaginatePreferences { host_name } = Default::default();
Self {
imaginate_server_hostname: "http://localhost:7860/".into(),
imaginate_server_hostname: host_name,
imaginate_refresh_frequency: 1.,
zoom_with_scroll: matches!(MappingVariant::default(), MappingVariant::ZoomWithScroll),
}
@ -28,9 +38,9 @@ impl MessageHandler<PreferencesMessage, ()> for PreferencesMessageHandler {
if let Ok(deserialized_preferences) = serde_json::from_str::<PreferencesMessageHandler>(&preferences) {
*self = deserialized_preferences;
if self.imaginate_server_hostname != Self::default().imaginate_server_hostname {
responses.add(PortfolioMessage::ImaginateCheckServerStatus);
}
responses.add(PortfolioMessage::ImaginateServerHostname);
responses.add(PortfolioMessage::ImaginateCheckServerStatus);
responses.add(PortfolioMessage::ImaginatePreferences);
}
}
PreferencesMessage::ResetToDefaults => {
@ -43,6 +53,7 @@ impl MessageHandler<PreferencesMessage, ()> for PreferencesMessageHandler {
PreferencesMessage::ImaginateRefreshFrequency { seconds } => {
self.imaginate_refresh_frequency = seconds;
responses.add(PortfolioMessage::ImaginateCheckServerStatus);
responses.add(PortfolioMessage::ImaginatePreferences);
}
PreferencesMessage::ImaginateServerHostname { hostname } => {
let initial = hostname.clone();
@ -55,7 +66,9 @@ impl MessageHandler<PreferencesMessage, ()> for PreferencesMessageHandler {
}
self.imaginate_server_hostname = hostname;
responses.add(PortfolioMessage::ImaginateServerHostname);
responses.add(PortfolioMessage::ImaginateCheckServerStatus);
responses.add(PortfolioMessage::ImaginatePreferences);
}
PreferencesMessage::ModifyLayout { zoom_with_scroll } => {
self.zoom_with_scroll = zoom_with_scroll;

View file

@ -10,16 +10,16 @@ use document_legacy::{LayerId, Operation};
use graph_craft::document::value::TaggedValue;
use graph_craft::document::{generate_uuid, DocumentNodeImplementation, NodeId, NodeNetwork};
use graph_craft::graphene_compiler::Compiler;
use graph_craft::imaginate_input::ImaginatePreferences;
use graph_craft::{concrete, Type, TypeDescriptor};
use graphene_core::application_io::ApplicationIo;
use graphene_core::application_io::{ApplicationIo, NodeGraphUpdateMessage, NodeGraphUpdateSender};
use graphene_core::raster::{Image, ImageFrame};
use graphene_core::renderer::{SvgSegment, SvgSegmentList};
use graphene_core::text::FontCache;
use graphene_core::vector::style::ViewMode;
use graphene_core::{Color, SurfaceFrame, SurfaceId};
use graphene_std::wasm_application_io::WasmApplicationIo;
use graphene_std::wasm_application_io::WasmEditorApi;
use graphene_std::wasm_application_io::{WasmApplicationIo, WasmEditorApi};
use interpreted_executor::dynamic_executor::DynamicExecutor;
use glam::{DAffine2, DVec2};
@ -33,8 +33,9 @@ pub struct NodeRuntime {
pub(crate) executor: DynamicExecutor,
font_cache: FontCache,
receiver: Receiver<NodeRuntimeMessage>,
sender: Sender<GenerationResponse>,
sender: InternalNodeGraphUpdateSender,
wasm_io: Option<WasmApplicationIo>,
imaginate_preferences: ImaginatePreferences,
pub(crate) thumbnails: HashMap<LayerId, HashMap<NodeId, SvgSegmentList>>,
canvas_cache: HashMap<Vec<LayerId>, SurfaceId>,
}
@ -42,6 +43,7 @@ pub struct NodeRuntime {
enum NodeRuntimeMessage {
GenerationRequest(GenerationRequest),
FontCacheUpdate(FontCache),
ImaginatePreferencesUpdate(ImaginatePreferences),
}
pub(crate) struct GenerationRequest {
@ -50,6 +52,7 @@ pub(crate) struct GenerationRequest {
path: Vec<LayerId>,
image_frame: Option<ImageFrame<Color>>,
}
pub(crate) struct GenerationResponse {
generation_id: u64,
result: Result<TaggedValue, String>,
@ -57,18 +60,38 @@ pub(crate) struct GenerationResponse {
new_thumbnails: HashMap<LayerId, HashMap<NodeId, SvgSegmentList>>,
}
enum NodeGraphUpdate {
GenerationResponse(GenerationResponse),
NodeGraphUpdateMessage(NodeGraphUpdateMessage),
}
struct InternalNodeGraphUpdateSender(Sender<NodeGraphUpdate>);
impl InternalNodeGraphUpdateSender {
fn send_generation_response(&self, response: GenerationResponse) {
self.0.send(NodeGraphUpdate::GenerationResponse(response)).expect("Failed to send response")
}
}
impl NodeGraphUpdateSender for InternalNodeGraphUpdateSender {
fn send(&self, message: NodeGraphUpdateMessage) {
self.0.send(NodeGraphUpdate::NodeGraphUpdateMessage(message)).expect("Failed to send response")
}
}
thread_local! {
pub(crate) static NODE_RUNTIME: Rc<RefCell<Option<NodeRuntime>>> = Rc::new(RefCell::new(None));
}
impl NodeRuntime {
fn new(receiver: Receiver<NodeRuntimeMessage>, sender: Sender<GenerationResponse>) -> Self {
fn new(receiver: Receiver<NodeRuntimeMessage>, sender: Sender<NodeGraphUpdate>) -> Self {
let executor = DynamicExecutor::default();
Self {
executor,
receiver,
sender,
sender: InternalNodeGraphUpdateSender(sender),
font_cache: FontCache::default(),
imaginate_preferences: Default::default(),
thumbnails: Default::default(),
wasm_io: None,
canvas_cache: Default::default(),
@ -80,13 +103,14 @@ impl NodeRuntime {
// This should be avoided in the future.
requests.reverse();
requests.dedup_by_key(|x| match x {
NodeRuntimeMessage::FontCacheUpdate(_) => None,
NodeRuntimeMessage::GenerationRequest(x) => Some(x.path.clone()),
_ => None,
});
requests.reverse();
for request in requests {
match request {
NodeRuntimeMessage::FontCacheUpdate(font_cache) => self.font_cache = font_cache,
NodeRuntimeMessage::ImaginatePreferencesUpdate(preferences) => self.imaginate_preferences = preferences,
NodeRuntimeMessage::GenerationRequest(GenerationRequest {
generation_id,
graph,
@ -105,7 +129,7 @@ impl NodeRuntime {
updates: responses,
new_thumbnails: self.thumbnails.clone(),
};
self.sender.send(response).expect("Failed to send response");
self.sender.send_generation_response(response);
}
}
}
@ -134,6 +158,8 @@ impl NodeRuntime {
font_cache: &self.font_cache,
image_frame,
application_io: &self.wasm_io.as_ref().unwrap(),
node_graph_message_sender: &self.sender,
imaginate_preferences: &self.imaginate_preferences,
};
// We assume only one output
@ -240,7 +266,7 @@ pub async fn run_node_graph() {
#[derive(Debug)]
pub struct NodeGraphExecutor {
sender: Sender<NodeRuntimeMessage>,
receiver: Receiver<GenerationResponse>,
receiver: Receiver<NodeGraphUpdate>,
// TODO: This is a memory leak since layers are never removed
pub(crate) last_output_type: HashMap<Vec<LayerId>, Option<Type>>,
pub(crate) thumbnails: HashMap<LayerId, HashMap<NodeId, SvgSegmentList>>,
@ -294,10 +320,31 @@ impl NodeGraphExecutor {
self.sender.send(NodeRuntimeMessage::FontCacheUpdate(font_cache)).expect("Failed to send font cache update");
}
pub fn update_imaginate_preferences(&self, imaginate_preferences: ImaginatePreferences) {
self.sender
.send(NodeRuntimeMessage::ImaginatePreferencesUpdate(imaginate_preferences))
.expect("Failed to send imaginate preferences");
}
pub fn previous_output_type(&self, path: &[LayerId]) -> Option<Type> {
self.last_output_type.get(path).cloned().flatten()
}
pub fn introspect_node_in_network<T: std::any::Any + core::fmt::Debug, U, F1: FnOnce(&NodeNetwork) -> Option<NodeId>, F2: FnOnce(&T) -> U>(
&mut self,
network: &NodeNetwork,
node_path: &[NodeId],
find_node: F1,
extract_data: F2,
) -> Option<U> {
let wrapping_document_node = network.nodes.get(node_path.last()?)?;
let DocumentNodeImplementation::Network(wrapped_network) = &wrapping_document_node.implementation else { return None; };
let introspection_node = find_node(&wrapped_network)?;
let introspection = self.introspect_node(&[node_path, &[introspection_node]].concat())?;
let downcasted: &T = <dyn std::any::Any>::downcast_ref(introspection.as_ref())?;
Some(extract_data(downcasted))
}
/// Encodes an image into a format using the image crate
fn encode_img(image: Image<Color>, resize: Option<DVec2>, format: image::ImageOutputFormat) -> Result<(Vec<u8>, (u32, u32)), String> {
use image::{ImageBuffer, Rgba};
@ -334,13 +381,12 @@ impl NodeGraphExecutor {
})
}
/// Evaluates a node graph, computing either the Imaginate node or the entire graph
/// Evaluates a node graph, computing the entire graph
pub fn submit_node_graph_evaluation(
&mut self,
(document_id, documents): (u64, &mut HashMap<u64, DocumentMessageHandler>),
layer_path: Vec<LayerId>,
(input_image_data, (width, height)): (Vec<u8>, (u32, u32)),
_imaginate_node: Option<Vec<NodeId>>,
_persistent_data: (&PreferencesMessageHandler, &PersistentData),
_responses: &mut VecDeque<Message>,
) -> Result<(), String> {
@ -365,11 +411,6 @@ impl NodeGraphExecutor {
let transform = DAffine2::IDENTITY;
let image_frame = ImageFrame { image, transform };
// Special execution path for generating Imaginate (as generation requires IO from outside node graph)
/*if let Some(imaginate_node) = imaginate_node {
responses.add(self.generate_imaginate(network, imaginate_node, (document, document_id), layer_path, editor_api, persistent_data)?);
return Ok(());
}*/
// Execute the node graph
let generation_id = self.queue_execution(network, Some(image_frame), layer_path.clone());
@ -381,26 +422,32 @@ impl NodeGraphExecutor {
pub fn poll_node_graph_evaluation(&mut self, responses: &mut VecDeque<Message>) -> Result<(), String> {
let results = self.receiver.try_iter().collect::<Vec<_>>();
for response in results {
let GenerationResponse {
generation_id,
result,
updates,
new_thumbnails,
} = response;
self.thumbnails = new_thumbnails;
let node_graph_output = result.map_err(|e| format!("Node graph evaluation failed: {:?}", e))?;
let execution_context = self.futures.remove(&generation_id).ok_or_else(|| "Invalid generation ID".to_string())?;
responses.extend(updates);
self.process_node_graph_output(node_graph_output, execution_context.layer_path.clone(), responses, execution_context.document_id)?;
responses.add(DocumentMessage::LayerChanged {
affected_layer_path: execution_context.layer_path,
});
responses.add(DocumentMessage::RenderDocument);
responses.add(ArtboardMessage::RenderArtboards);
responses.add(DocumentMessage::DocumentStructureChanged);
responses.add(BroadcastEvent::DocumentIsDirty);
responses.add(DocumentMessage::DirtyRenderDocument);
responses.add(DocumentMessage::Overlays(OverlaysMessage::Rerender));
match response {
NodeGraphUpdate::GenerationResponse(GenerationResponse {
generation_id,
result,
updates,
new_thumbnails,
}) => {
self.thumbnails = new_thumbnails;
let node_graph_output = result.map_err(|e| format!("Node graph evaluation failed: {:?}", e))?;
let execution_context = self.futures.remove(&generation_id).ok_or_else(|| "Invalid generation ID".to_string())?;
responses.extend(updates);
self.process_node_graph_output(node_graph_output, execution_context.layer_path.clone(), responses, execution_context.document_id)?;
responses.add(DocumentMessage::LayerChanged {
affected_layer_path: execution_context.layer_path,
});
responses.add(DocumentMessage::RenderDocument);
responses.add(ArtboardMessage::RenderArtboards);
responses.add(DocumentMessage::DocumentStructureChanged);
responses.add(BroadcastEvent::DocumentIsDirty);
responses.add(DocumentMessage::DirtyRenderDocument);
responses.add(DocumentMessage::Overlays(OverlaysMessage::Rerender));
}
NodeGraphUpdate::NodeGraphUpdateMessage(NodeGraphUpdateMessage::ImaginateStatusUpdate) => {
responses.add(DocumentMessage::PropertiesPanel(PropertiesPanelMessage::ResendActiveProperties))
}
}
}
Ok(())
}

View file

@ -3,7 +3,6 @@
import {writable} from "svelte/store";
import { downloadFileText, downloadFileBlob, upload, downloadFileURL } from "@graphite/utility-functions/files";
import { imaginateGenerate, imaginateCheckConnection, imaginateTerminate, updateBackendImage } from "@graphite/utility-functions/imaginate";
import { extractPixelData, imageToPNG, rasterizeSVG, rasterizeSVGCanvas } from "@graphite/utility-functions/rasterization";
import { type Editor } from "@graphite/wasm-communication/editor";
import {
@ -13,8 +12,6 @@ import {
TriggerDownloadRaster,
TriggerDownloadTextFile,
TriggerImaginateCheckServerStatus,
TriggerImaginateGenerate,
TriggerImaginateTerminate,
TriggerImport,
TriggerOpenDocument,
TriggerRasterizeRegionBelowLayer,
@ -85,37 +82,6 @@ export function createPortfolioState(editor: Editor) {
// Have the browser download the file to the user's disk
downloadFileBlob(name, blob);
});
editor.subscriptions.subscribeJsMessage(TriggerImaginateCheckServerStatus, async (triggerImaginateCheckServerStatus) => {
const { hostname } = triggerImaginateCheckServerStatus;
imaginateCheckConnection(hostname, editor);
});
editor.subscriptions.subscribeJsMessage(TriggerImaginateGenerate, async (triggerImaginateGenerate) => {
const { documentId, layerPath, nodePath, hostname, refreshFrequency, baseImage, maskImage, maskPaintMode, maskBlurPx, maskFillContent, parameters } = triggerImaginateGenerate;
// Handle img2img mode
let image: Blob | undefined;
if (parameters.denoisingStrength !== undefined && baseImage !== undefined) {
const buffer = new Uint8Array(baseImage.imageData.values()).buffer;
image = new Blob([buffer], { type: baseImage.mime });
updateBackendImage(editor, image, documentId, layerPath, nodePath);
}
// Handle layer mask
let mask: Blob | undefined;
if (maskImage !== undefined) {
// Rasterize the SVG to an image file
mask = await rasterizeSVG(maskImage.svg, maskImage.size[0], maskImage.size[1], "image/png");
}
imaginateGenerate(parameters, image, mask, maskPaintMode, maskBlurPx, maskFillContent, hostname, refreshFrequency, documentId, layerPath, nodePath, editor);
});
editor.subscriptions.subscribeJsMessage(TriggerImaginateTerminate, async (triggerImaginateTerminate) => {
const { documentId, layerPath, nodePath, hostname } = triggerImaginateTerminate;
imaginateTerminate(hostname, documentId, layerPath, nodePath, editor);
});
editor.subscriptions.subscribeJsMessage(UpdateImageData, (updateImageData) => {
updateImageData.imageData.forEach(async (element) => {
const buffer = new Uint8Array(element.imageData.values()).buffer;

View file

@ -1,367 +0,0 @@
/* eslint-disable camelcase */
// import { escapeJSON } from "@graphite/utility-functions/escape";
import { blobToBase64 } from "@graphite/utility-functions/files";
import { type RequestResult, requestWithUploadDownloadProgress } from "@graphite/utility-functions/network";
import { type Editor } from "@graphite/wasm-communication/editor";
import type { XY } from "@graphite/wasm-communication/messages";
import { type ImaginateGenerationParameters } from "@graphite/wasm-communication/messages";
const MAX_POLLING_RETRIES = 4;
const SERVER_STATUS_CHECK_TIMEOUT = 5000;
const PROGRESS_EVERY_N_STEPS = 5;
let timer: NodeJS.Timeout | undefined;
let terminated = false;
let generatingAbortRequest: XMLHttpRequest | undefined;
let pollingAbortController = new AbortController();
let statusAbortController = new AbortController();
// PUBLICLY CALLABLE FUNCTIONS
export async function imaginateGenerate(
parameters: ImaginateGenerationParameters,
image: Blob | undefined,
mask: Blob | undefined,
maskPaintMode: string,
maskBlurPx: number,
maskFillContent: string,
hostname: string,
refreshFrequency: number,
documentId: bigint,
layerPath: BigUint64Array,
nodePath: BigUint64Array,
editor: Editor
): Promise<void> {
// Ignore a request to generate a new image while another is already being generated
if (generatingAbortRequest !== undefined) return;
terminated = false;
// Immediately set the progress to 0% so the backend knows to update its layout
editor.instance.setImaginateGeneratingStatus(documentId, layerPath, nodePath, 0, "Beginning");
// Initiate a request to the computation server
const discloseUploadingProgress = (progress: number): void => {
editor.instance.setImaginateGeneratingStatus(documentId, layerPath, nodePath, progress * 100, "Uploading");
};
const { uploaded, result, xhr } = await generate(discloseUploadingProgress, hostname, image, mask, maskPaintMode, maskBlurPx, maskFillContent, parameters);
generatingAbortRequest = xhr;
try {
// Wait until the request is fully uploaded, which could be slow if the img2img source is large and the user is on a slow connection
await uploaded;
editor.instance.setImaginateGeneratingStatus(documentId, layerPath, nodePath, 0, "Generating");
// Begin polling for updates to the in-progress image generation at the specified interval
// Don't poll if the chosen interval is 0, or if the chosen sampling method does not support polling
if (refreshFrequency > 0) {
const interval = Math.max(refreshFrequency * 1000, 500);
scheduleNextPollingUpdate(interval, Date.now(), 0, editor, hostname, documentId, layerPath, nodePath, parameters.resolution);
}
// Wait for the final image to be returned by the initial request containing either the full image or the last frame if it was terminated by the user
const { body, status } = await result;
if (status < 200 || status > 299) {
throw new Error(`Request to server failed to return a 200-level status code (${status})`);
}
// Extract the final image from the response and convert it to a data blob
const base64Data = JSON.parse(body)?.images?.[0] as string | undefined;
const base64 = typeof base64Data === "string" && base64Data.length > 0 ? `data:image/png;base64,${base64Data}` : undefined;
if (!base64) throw new Error("Could not read final image result from server response");
const blob = await (await fetch(base64)).blob();
// Send the backend an updated status
const percent = terminated ? undefined : 100;
const newStatus = terminated ? "Terminated" : "Idle";
editor.instance.setImaginateGeneratingStatus(documentId, layerPath, nodePath, percent, newStatus);
// Send the backend a blob URL for the final image
updateBackendImage(editor, blob, documentId, layerPath, nodePath);
} catch {
editor.instance.setImaginateGeneratingStatus(documentId, layerPath, nodePath, undefined, "Terminated");
await imaginateCheckConnection(hostname, editor);
}
abortAndResetGenerating();
abortAndResetPolling();
}
export async function imaginateTerminate(hostname: string, documentId: bigint, layerPath: BigUint64Array, nodePath: BigUint64Array, editor: Editor): Promise<void> {
terminated = true;
abortAndResetPolling();
try {
await terminate(hostname);
editor.instance.setImaginateGeneratingStatus(documentId, layerPath, nodePath, undefined, "Terminating");
} catch {
abortAndResetGenerating();
abortAndResetPolling();
editor.instance.setImaginateGeneratingStatus(documentId, layerPath, nodePath, undefined, "Terminated");
await imaginateCheckConnection(hostname, editor);
}
}
export async function imaginateCheckConnection(hostname: string, editor: Editor): Promise<void> {
const serverReached = await checkConnection(hostname);
editor.instance.setImaginateServerStatus(serverReached);
}
// Converts the blob image into a list of pixels using an invisible canvas.
export async function updateBackendImage(editor: Editor, blob: Blob, documentId: bigint, layerPath: BigUint64Array, nodePath: BigUint64Array): Promise<void> {
const image = await createImageBitmap(blob);
const canvas = document.createElement("canvas");
canvas.width = image.width;
canvas.height = image.height;
const ctx = canvas.getContext("2d");
if (!ctx) throw new Error("Could not create canvas context");
ctx.drawImage(image, 0, 0);
// Send the backend the blob data to be stored persistently in the layer
const imageData = ctx.getImageData(0, 0, image.width, image.height);
const u8Array = new Uint8Array(imageData.data);
editor.instance.setImaginateImageData(documentId, layerPath, nodePath, u8Array, imageData.width, imageData.height);
}
// ABORTING AND RESETTING HELPERS
function abortAndResetGenerating(): void {
generatingAbortRequest?.abort();
generatingAbortRequest = undefined;
}
function abortAndResetPolling(): void {
pollingAbortController.abort();
pollingAbortController = new AbortController();
clearTimeout(timer);
}
// POLLING IMPLEMENTATION DETAILS
function scheduleNextPollingUpdate(
interval: number,
timeoutBegan: number,
pollingRetries: number,
editor: Editor,
hostname: string,
documentId: bigint,
layerPath: BigUint64Array,
nodePath: BigUint64Array,
resolution: XY
): void {
// Pick a future time that keeps to the user-requested interval if possible, but on slower connections will go as fast as possible without overlapping itself
const nextPollTimeGoal = timeoutBegan + interval;
const timeFromNow = Math.max(0, nextPollTimeGoal - Date.now());
timer = setTimeout(async () => {
const nextTimeoutBegan = Date.now();
try {
const [blob, percentComplete] = await pollImage(hostname);
// After waiting for the polling result back from the server, if during that intervening time the user has terminated the generation, exit so we don't overwrite that terminated status
if (terminated) return;
if (blob) updateBackendImage(editor, blob, documentId, layerPath, nodePath);
editor.instance.setImaginateGeneratingStatus(documentId, layerPath, nodePath, percentComplete, "Generating");
scheduleNextPollingUpdate(interval, nextTimeoutBegan, 0, editor, hostname, documentId, layerPath, nodePath, resolution);
} catch {
if (generatingAbortRequest === undefined) return;
if (pollingRetries + 1 > MAX_POLLING_RETRIES) {
abortAndResetGenerating();
abortAndResetPolling();
await imaginateCheckConnection(hostname, editor);
} else {
scheduleNextPollingUpdate(interval, nextTimeoutBegan, pollingRetries + 1, editor, hostname, documentId, layerPath, nodePath, resolution);
}
}
}, timeFromNow);
}
// API COMMUNICATION FUNCTIONS
async function pollImage(hostname: string): Promise<[Blob | undefined, number]> {
// Fetch the percent progress and in-progress image from the API
const result = await fetch(`${hostname}sdapi/v1/progress`, { signal: pollingAbortController.signal, method: "GET" });
const { current_image, progress } = await result.json();
// Convert to a usable format
const progressPercent = progress * 100;
const base64 = typeof current_image === "string" && current_image.length > 0 ? `data:image/png;base64,${current_image}` : undefined;
// Deal with a missing image
if (!base64) {
// The image is not ready yet (because it's only had a few samples since generation began), but we do have a progress percentage
if (!Number.isNaN(progressPercent) && progressPercent >= 0 && progressPercent <= 100) {
return [undefined, progressPercent];
}
// Something else is wrong and the image wasn't provided as expected
return Promise.reject();
}
// The image was provided so we turn it into a data blob
const blob = await (await fetch(base64)).blob();
return [blob, progressPercent];
}
async function generate(
discloseUploadingProgress: (progress: number) => void,
hostname: string,
image: Blob | undefined,
mask: Blob | undefined,
maskPaintMode: string,
maskBlurPx: number,
maskFillContent: string,
parameters: ImaginateGenerationParameters
): Promise<{
uploaded: Promise<void>;
result: Promise<RequestResult>;
xhr?: XMLHttpRequest;
}> {
let body;
let endpoint;
if (image === undefined || parameters.denoisingStrength === undefined) {
endpoint = `${hostname}sdapi/v1/txt2img`;
body = {
// enable_hr: false,
// denoising_strength: 0,
// firstphase_width: 0,
// firstphase_height: 0,
prompt: parameters.prompt,
// styles: [],
seed: Number(parameters.seed),
// subseed: -1,
// subseed_strength: 0,
// seed_resize_from_h: -1,
// seed_resize_from_w: -1,
// batch_size: 1,
// n_iter: 1,
steps: parameters.samples,
cfg_scale: parameters.cfgScale,
width: parameters.resolution.x,
height: parameters.resolution.y,
restore_faces: parameters.restoreFaces,
tiling: parameters.tiling,
negative_prompt: parameters.negativePrompt,
// eta: 0,
// s_churn: 0,
// s_tmax: 0,
// s_tmin: 0,
// s_noise: 1,
override_settings: {
show_progress_every_n_steps: PROGRESS_EVERY_N_STEPS,
},
sampler_index: parameters.samplingMethod,
};
} else {
const sourceImageBase64 = await blobToBase64(image);
const maskImageBase64 = mask ? await blobToBase64(mask) : "";
const maskFillContentIndexes = ["Fill", "Original", "LatentNoise", "LatentNothing"];
const maskFillContentIndexFound = maskFillContentIndexes.indexOf(maskFillContent);
const maskFillContentIndex = maskFillContentIndexFound === -1 ? undefined : maskFillContentIndexFound;
const maskInvert = maskPaintMode === "Inpaint" ? 1 : 0;
endpoint = `${hostname}sdapi/v1/img2img`;
body = {
init_images: [sourceImageBase64],
// resize_mode: 0,
denoising_strength: parameters.denoisingStrength,
mask: mask && maskImageBase64,
mask_blur: mask && maskBlurPx,
inpainting_fill: mask && maskFillContentIndex,
inpaint_full_res: mask && false,
// inpaint_full_res_padding: 0,
inpainting_mask_invert: mask && maskInvert,
prompt: parameters.prompt,
// styles: [],
seed: Number(parameters.seed),
// subseed: -1,
// subseed_strength: 0,
// seed_resize_from_h: -1,
// seed_resize_from_w: -1,
// batch_size: 1,
// n_iter: 1,
steps: parameters.samples,
cfg_scale: parameters.cfgScale,
width: parameters.resolution.x,
height: parameters.resolution.y,
restore_faces: parameters.restoreFaces,
tiling: parameters.tiling,
negative_prompt: parameters.negativePrompt,
// eta: 0,
// s_churn: 0,
// s_tmax: 0,
// s_tmin: 0,
// s_noise: 1,
override_settings: {
show_progress_every_n_steps: PROGRESS_EVERY_N_STEPS,
img2img_fix_steps: true,
},
sampler_index: parameters.samplingMethod,
// include_init_images: false,
};
}
// Prepare a promise that will resolve after the outbound request upload is complete
let uploadedResolve: () => void;
let uploadedReject: () => void;
const uploaded = new Promise<void>((resolve, reject): void => {
uploadedResolve = resolve;
uploadedReject = reject;
});
// Fire off the request and, once the outbound request upload is complete, resolve the promise we defined above
const uploadProgress = (progress: number): void => {
if (progress < 1) {
discloseUploadingProgress(progress);
} else {
uploadedResolve();
}
};
const [result, xhr] = requestWithUploadDownloadProgress(endpoint, "POST", JSON.stringify(body), uploadProgress, abortAndResetPolling);
result.catch(() => uploadedReject());
// Return the promise that resolves when the request upload is complete, the promise that resolves when the response download is complete, and the XHR so it can be aborted
return { uploaded, result, xhr };
}
async function terminate(hostname: string): Promise<void> {
await fetch(`${hostname}sdapi/v1/interrupt`, { method: "POST" });
}
async function checkConnection(hostname: string): Promise<boolean> {
statusAbortController.abort();
statusAbortController = new AbortController();
const timeout = setTimeout(() => statusAbortController.abort(), SERVER_STATUS_CHECK_TIMEOUT);
try {
// Intentionally misuse this API endpoint by using it just to check for a code 200 response, regardless of what the result is
const { status } = await fetch(`${hostname}sdapi/v1/progress?skip_current_image=true`, { signal: statusAbortController.signal, method: "GET" });
// This code means the server has indeed responded and the endpoint exists (otherwise it would be 404)
if (status === 200) {
clearTimeout(timeout);
return true;
}
} catch {
// Do nothing here
}
return false;
}

View file

@ -95,3 +95,7 @@ export function createEditor() {
subscriptions,
};
}
export function injectImaginatePollServerStatus() {
window["editorInstance"]?.injectImaginatePollServerStatus()
}

View file

@ -516,7 +516,7 @@ export class TriggerCopyToClipboardBlobUrl extends JsMessage {
export class TriggerDownloadBlobUrl extends JsMessage {
readonly layerName!: string;
readonly blobUrl!: string;
}
@ -537,85 +537,6 @@ export class TriggerDownloadTextFile extends JsMessage {
readonly name!: string;
}
export class TriggerImaginateCheckServerStatus extends JsMessage {
readonly hostname!: string;
}
export class TriggerImaginateGenerate extends JsMessage {
@Type(() => ImaginateGenerationParameters)
readonly parameters!: ImaginateGenerationParameters;
@Type(() => ImaginateBaseImage)
readonly baseImage!: ImaginateBaseImage | undefined;
@Type(() => ImaginateMaskImage)
readonly maskImage: ImaginateMaskImage | undefined;
readonly maskPaintMode!: string;
readonly maskBlurPx!: number;
readonly maskFillContent!: string;
readonly hostname!: string;
readonly refreshFrequency!: number;
readonly documentId!: bigint;
readonly layerPath!: BigUint64Array;
readonly nodePath!: BigUint64Array;
}
export class ImaginateMaskImage {
readonly svg!: string;
readonly size!: [number, number];
}
export class ImaginateBaseImage {
readonly mime!: string;
readonly imageData!: Uint8Array;
@TupleToVec2
readonly size!: [number, number];
}
export class ImaginateGenerationParameters {
readonly seed!: number;
readonly samples!: number;
readonly samplingMethod!: string;
readonly denoisingStrength!: number | undefined;
readonly cfgScale!: number;
readonly prompt!: string;
readonly negativePrompt!: string;
@BigIntTupleToVec2
readonly resolution!: XY;
readonly restoreFaces!: boolean;
readonly tiling!: boolean;
}
export class TriggerImaginateTerminate extends JsMessage {
readonly documentId!: bigint;
readonly layerPath!: BigUint64Array;
readonly nodePath!: BigUint64Array;
readonly hostname!: string;
}
export class TriggerRasterizeRegionBelowLayer extends JsMessage {
readonly documentId!: bigint;
@ -778,7 +699,7 @@ export class ImaginateImageData {
readonly mime!: string;
readonly imageData!: Uint8Array;
readonly transform!: Float64Array ;
}
@ -1404,9 +1325,6 @@ export const messageMakers: Record<string, MessageMaker> = {
TriggerDownloadRaster,
TriggerDownloadTextFile,
TriggerFontLoad,
TriggerImaginateCheckServerStatus,
TriggerImaginateGenerate,
TriggerImaginateTerminate,
TriggerImport,
TriggerIndexedDbRemoveDocument,
TriggerIndexedDbWriteDocument,

View file

@ -30,9 +30,9 @@ graphene-core = { path = "../../node-graph/gcore", features = [
] }
serde = { version = "1.0", features = ["derive"] }
wasm-bindgen = { version = "=0.2.86" }
serde-wasm-bindgen = "0.4.1"
js-sys = "0.3.55"
wasm-bindgen-futures = "0.4.33"
serde-wasm-bindgen = "0.5.0"
js-sys = "0.3.63"
wasm-bindgen-futures = "0.4.36"
ron = { version = "0.8", optional = true }
bezier-rs = { path = "../../libraries/bezier-rs" }

View file

@ -11,7 +11,7 @@ use editor::application::Editor;
use editor::consts::{FILE_SAVE_SUFFIX, GRAPHITE_DOCUMENT_VERSION};
use editor::messages::input_mapper::utility_types::input_keyboard::ModifierKeys;
use editor::messages::input_mapper::utility_types::input_mouse::{EditorMouseState, ScrollDelta, ViewportBounds};
use editor::messages::portfolio::utility_types::{ImaginateServerStatus, Platform};
use editor::messages::portfolio::utility_types::Platform;
use editor::messages::prelude::*;
use graph_craft::document::NodeId;
use graphene_core::raster::color::Color;
@ -586,63 +586,6 @@ impl JsEditorHandle {
}
}
/// Sends the blob URL generated by JS to the Imaginate layer in the respective document
#[wasm_bindgen(js_name = setImaginateImageData)]
pub fn set_imaginate_image_data(&self, document_id: u64, layer_path: Vec<LayerId>, node_path: Vec<NodeId>, image_data: Vec<u8>, width: u32, height: u32) {
let message = PortfolioMessage::ImaginateSetImageData {
document_id,
node_path,
layer_path,
image_data,
width,
height,
};
self.dispatch(message);
}
/// Notifies the Imaginate layer of a new percentage of completion and whether or not it's currently generating
#[wasm_bindgen(js_name = setImaginateGeneratingStatus)]
pub fn set_imaginate_generating_status(&self, document_id: u64, layer_path: Vec<LayerId>, node_path: Vec<NodeId>, percent: Option<f64>, status: String) {
use graph_craft::imaginate_input::ImaginateStatus;
let status = match status.as_str() {
"Idle" => ImaginateStatus::Idle,
"Beginning" => ImaginateStatus::Beginning,
"Uploading" => ImaginateStatus::Uploading(percent.expect("Percent needs to be supplied to set ImaginateStatus::Uploading")),
"Generating" => ImaginateStatus::Generating,
"Terminating" => ImaginateStatus::Terminating,
"Terminated" => ImaginateStatus::Terminated,
_ => panic!("Invalid string from JS for ImaginateStatus, received: {}", status),
};
let percent = if matches!(status, ImaginateStatus::Uploading(_)) { None } else { percent };
let message = PortfolioMessage::ImaginateSetGeneratingStatus {
document_id,
layer_path,
node_path,
percent,
status,
};
self.dispatch(message);
}
/// Notifies the editor that the Imaginate server is available or unavailable
#[wasm_bindgen(js_name = setImaginateServerStatus)]
pub fn set_imaginate_server_status(&self, available: bool) {
let message: Message = match available {
true => PortfolioMessage::ImaginateSetServerStatus {
status: ImaginateServerStatus::Connected,
}
.into(),
false => PortfolioMessage::ImaginateSetServerStatus {
status: ImaginateServerStatus::Unavailable,
}
.into(),
};
self.dispatch(message);
}
/// Sends the blob URL generated by JS to the Imaginate layer in the respective document
#[wasm_bindgen(js_name = renderGraphUsingRasterizedRegionBelowLayer)]
pub fn render_graph_using_rasterized_region_below_layer(
@ -793,6 +736,11 @@ impl JsEditorHandle {
});
frontend_messages.unwrap().unwrap_or_default()
}
#[wasm_bindgen(js_name = injectImaginatePollServerStatus)]
pub fn inject_imaginate_poll_server_status(&self) {
self.dispatch(PortfolioMessage::ImaginatePollServerStatus);
}
}
// Needed to make JsEditorHandle functions pub to Rust.

View file

@ -12,4 +12,9 @@ futures = "0.3.25"
log = "0.4"
[target.wasm32-unknown-unknown.dependencies]
wasm-rs-async-executor = {version = "0.9.0", features = ["cooperative-browser", "debug", "requestIdleCallback"] }
wasm-rs-async-executor = { version = "0.9.0", features = [
"cooperative-browser",
"debug",
"requestIdleCallback",
] }
wasm-bindgen-futures = "0.4.36"

View file

@ -22,3 +22,8 @@ pub fn block_on<F: Future + 'static>(future: F) -> F::Output {
#[cfg(not(target_arch = "wasm32"))]
futures::executor::block_on(future)
}
#[cfg(target_arch = "wasm32")]
pub fn spawn<F: Future<Output = ()> + 'static>(future: F) {
wasm_bindgen_futures::spawn_local(future);
}

View file

@ -112,10 +112,25 @@ impl<T: ApplicationIo> ApplicationIo for &T {
}
}
#[derive(Debug, Clone)]
pub enum NodeGraphUpdateMessage {
ImaginateStatusUpdate,
}
pub trait NodeGraphUpdateSender {
fn send(&self, message: NodeGraphUpdateMessage);
}
pub trait GetImaginatePreferences {
fn get_host_name(&self) -> &str;
}
pub struct EditorApi<'a, Io> {
pub image_frame: Option<ImageFrame<Color>>,
pub font_cache: &'a FontCache,
pub application_io: &'a Io,
pub node_graph_message_sender: &'a dyn NodeGraphUpdateSender,
pub imaginate_preferences: &'a dyn GetImaginatePreferences,
}
impl<'a, Io> Clone for EditorApi<'a, Io> {
@ -124,6 +139,8 @@ impl<'a, Io> Clone for EditorApi<'a, Io> {
image_frame: self.image_frame.clone(),
font_cache: self.font_cache,
application_io: self.application_io,
node_graph_message_sender: self.node_graph_message_sender,
imaginate_preferences: self.imaginate_preferences,
}
}
}

View file

@ -441,7 +441,7 @@ impl NodeNetwork {
}
/// Check if the specified node id is connected to the output
pub fn connected_to_output(&self, target_node_id: NodeId, ignore_imaginate: bool) -> bool {
pub fn connected_to_output(&self, target_node_id: NodeId) -> bool {
// If the node is the output then return true
if self.outputs.iter().any(|&NodeOutput { node_id, .. }| node_id == target_node_id) {
return true;
@ -454,11 +454,6 @@ impl NodeNetwork {
already_visited.extend(self.outputs.iter().map(|output| output.node_id));
while let Some(node) = stack.pop() {
// Skip the imaginate node inputs
if ignore_imaginate && node.name == "Imaginate" {
continue;
}
for input in &node.inputs {
if let &NodeInput::Node { node_id: ref_id, .. } = input {
// Skip if already viewed
@ -680,7 +675,7 @@ impl NodeNetwork {
let mut dummy_input = NodeInput::ShortCircut(concrete!(()));
std::mem::swap(&mut dummy_input, input);
if let NodeInput::Value { tagged_value, exposed } = dummy_input {
if let NodeInput::Value { mut tagged_value, exposed } = dummy_input {
let value_node_id = gen_id();
let merged_node_id = map_ids(id, value_node_id);
let path = if let Some(mut new_path) = node.path.clone() {

View file

@ -1,6 +1,6 @@
use super::DocumentNode;
use crate::graphene_compiler::Any;
pub use crate::imaginate_input::{ImaginateMaskStartingFill, ImaginateSamplingMethod, ImaginateStatus};
pub use crate::imaginate_input::{ImaginateCache, ImaginateController, ImaginateMaskStartingFill, ImaginateSamplingMethod};
use crate::proto::{Any as DAny, FutureAny};
use graphene_core::raster::brush_cache::BrushCache;
@ -27,7 +27,7 @@ pub enum TaggedValue {
OptionalDVec2(Option<DVec2>),
DAffine2(DAffine2),
Image(graphene_core::raster::Image<Color>),
RcImage(Option<Arc<graphene_core::raster::Image<Color>>>),
ImaginateCache(ImaginateCache),
ImageFrame(graphene_core::raster::ImageFrame<Color>),
Color(graphene_core::raster::color::Color),
Subpaths(Vec<bezier_rs::Subpath<graphene_core::uuid::ManipulatorGroupId>>),
@ -36,7 +36,7 @@ pub enum TaggedValue {
LuminanceCalculation(LuminanceCalculation),
ImaginateSamplingMethod(ImaginateSamplingMethod),
ImaginateMaskStartingFill(ImaginateMaskStartingFill),
ImaginateStatus(ImaginateStatus),
ImaginateController(ImaginateController),
LayerPath(Option<Vec<u64>>),
VectorData(graphene_core::vector::VectorData),
Fill(graphene_core::vector::style::Fill),
@ -83,7 +83,7 @@ impl Hash for TaggedValue {
}
Self::DAffine2(m) => m.to_cols_array().iter().for_each(|x| x.to_bits().hash(state)),
Self::Image(i) => i.hash(state),
Self::RcImage(i) => i.hash(state),
Self::ImaginateCache(i) => i.hash(state),
Self::Color(c) => c.hash(state),
Self::Subpaths(s) => s.iter().for_each(|subpath| subpath.hash(state)),
Self::RcSubpath(s) => s.hash(state),
@ -91,7 +91,7 @@ impl Hash for TaggedValue {
Self::LuminanceCalculation(l) => l.hash(state),
Self::ImaginateSamplingMethod(m) => m.hash(state),
Self::ImaginateMaskStartingFill(f) => f.hash(state),
Self::ImaginateStatus(s) => s.hash(state),
Self::ImaginateController(s) => s.hash(state),
Self::LayerPath(p) => p.hash(state),
Self::ImageFrame(i) => i.hash(state),
Self::VectorData(vector_data) => vector_data.hash(state),
@ -146,7 +146,7 @@ impl<'a> TaggedValue {
TaggedValue::OptionalDVec2(x) => Box::new(x),
TaggedValue::DAffine2(x) => Box::new(x),
TaggedValue::Image(x) => Box::new(x),
TaggedValue::RcImage(x) => Box::new(x),
TaggedValue::ImaginateCache(x) => Box::new(x),
TaggedValue::ImageFrame(x) => Box::new(x),
TaggedValue::Color(x) => Box::new(x),
TaggedValue::Subpaths(x) => Box::new(x),
@ -155,7 +155,7 @@ impl<'a> TaggedValue {
TaggedValue::LuminanceCalculation(x) => Box::new(x),
TaggedValue::ImaginateSamplingMethod(x) => Box::new(x),
TaggedValue::ImaginateMaskStartingFill(x) => Box::new(x),
TaggedValue::ImaginateStatus(x) => Box::new(x),
TaggedValue::ImaginateController(x) => Box::new(x),
TaggedValue::LayerPath(x) => Box::new(x),
TaggedValue::VectorData(x) => Box::new(x),
TaggedValue::Fill(x) => Box::new(x),
@ -210,7 +210,7 @@ impl<'a> TaggedValue {
TaggedValue::DVec2(_) => concrete!(DVec2),
TaggedValue::OptionalDVec2(_) => concrete!(Option<DVec2>),
TaggedValue::Image(_) => concrete!(graphene_core::raster::Image<Color>),
TaggedValue::RcImage(_) => concrete!(Option<Arc<graphene_core::raster::Image<Color>>>),
TaggedValue::ImaginateCache(_) => concrete!(ImaginateCache),
TaggedValue::ImageFrame(_) => concrete!(graphene_core::raster::ImageFrame<Color>),
TaggedValue::Color(_) => concrete!(graphene_core::raster::Color),
TaggedValue::Subpaths(_) => concrete!(Vec<bezier_rs::Subpath<graphene_core::uuid::ManipulatorGroupId>>),
@ -218,7 +218,7 @@ impl<'a> TaggedValue {
TaggedValue::BlendMode(_) => concrete!(BlendMode),
TaggedValue::ImaginateSamplingMethod(_) => concrete!(ImaginateSamplingMethod),
TaggedValue::ImaginateMaskStartingFill(_) => concrete!(ImaginateMaskStartingFill),
TaggedValue::ImaginateStatus(_) => concrete!(ImaginateStatus),
TaggedValue::ImaginateController(_) => concrete!(ImaginateController),
TaggedValue::LayerPath(_) => concrete!(Option<Vec<u64>>),
TaggedValue::DAffine2(_) => concrete!(DAffine2),
TaggedValue::LuminanceCalculation(_) => concrete!(LuminanceCalculation),
@ -263,7 +263,7 @@ impl<'a> TaggedValue {
x if x == TypeId::of::<DVec2>() => Ok(TaggedValue::DVec2(*downcast(input).unwrap())),
x if x == TypeId::of::<Option<DVec2>>() => Ok(TaggedValue::OptionalDVec2(*downcast(input).unwrap())),
x if x == TypeId::of::<graphene_core::raster::Image<Color>>() => Ok(TaggedValue::Image(*downcast(input).unwrap())),
x if x == TypeId::of::<Option<Arc<graphene_core::raster::Image<Color>>>>() => Ok(TaggedValue::RcImage(*downcast(input).unwrap())),
x if x == TypeId::of::<ImaginateCache>() => Ok(TaggedValue::ImaginateCache(*downcast(input).unwrap())),
x if x == TypeId::of::<graphene_core::raster::ImageFrame<Color>>() => Ok(TaggedValue::ImageFrame(*downcast(input).unwrap())),
x if x == TypeId::of::<graphene_core::raster::Color>() => Ok(TaggedValue::Color(*downcast(input).unwrap())),
x if x == TypeId::of::<Vec<bezier_rs::Subpath<graphene_core::uuid::ManipulatorGroupId>>>() => Ok(TaggedValue::Subpaths(*downcast(input).unwrap())),
@ -271,7 +271,7 @@ impl<'a> TaggedValue {
x if x == TypeId::of::<BlendMode>() => Ok(TaggedValue::BlendMode(*downcast(input).unwrap())),
x if x == TypeId::of::<ImaginateSamplingMethod>() => Ok(TaggedValue::ImaginateSamplingMethod(*downcast(input).unwrap())),
x if x == TypeId::of::<ImaginateMaskStartingFill>() => Ok(TaggedValue::ImaginateMaskStartingFill(*downcast(input).unwrap())),
x if x == TypeId::of::<ImaginateStatus>() => Ok(TaggedValue::ImaginateStatus(*downcast(input).unwrap())),
x if x == TypeId::of::<ImaginateController>() => Ok(TaggedValue::ImaginateController(*downcast(input).unwrap())),
x if x == TypeId::of::<Option<Vec<u64>>>() => Ok(TaggedValue::LayerPath(*downcast(input).unwrap())),
x if x == TypeId::of::<DAffine2>() => Ok(TaggedValue::DAffine2(*downcast(input).unwrap())),
x if x == TypeId::of::<LuminanceCalculation>() => Ok(TaggedValue::LuminanceCalculation(*downcast(input).unwrap())),

View file

@ -1,50 +1,155 @@
use dyn_any::{DynAny, StaticType};
use glam::DVec2;
use graphene_core::Color;
use std::borrow::Cow;
use std::fmt::Debug;
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc, Mutex,
};
#[derive(Default, Debug, Clone, Copy, PartialEq, DynAny, specta::Type)]
#[derive(Default, Debug, Clone, DynAny, specta::Type)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ImaginateCache(Arc<Mutex<graphene_core::raster::Image<Color>>>);
impl ImaginateCache {
pub fn into_inner(self) -> Arc<Mutex<graphene_core::raster::Image<Color>>> {
self.0
}
}
impl std::cmp::PartialEq for ImaginateCache {
fn eq(&self, other: &Self) -> bool {
Arc::ptr_eq(&self.0, &other.0)
}
}
impl core::hash::Hash for ImaginateCache {
fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
self.0.lock().unwrap().hash(state);
}
}
pub trait ImaginateTerminationHandle: Debug + Send + Sync + 'static {
fn terminate(&self);
}
#[derive(Default, Debug, specta::Type)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
struct InternalImaginateControl {
status: Mutex<ImaginateStatus>,
trigger_regenerate: AtomicBool,
#[serde(skip)]
termination_sender: Mutex<Option<Box<dyn ImaginateTerminationHandle>>>,
}
#[derive(Debug, Default, Clone, DynAny, specta::Type)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ImaginateController(Arc<InternalImaginateControl>);
impl ImaginateController {
pub fn get_status(&self) -> ImaginateStatus {
self.0.status.lock().as_deref().cloned().unwrap_or_default()
}
pub fn set_status(&self, status: ImaginateStatus) {
if let Ok(mut lock) = self.0.status.lock() {
*lock = status
}
}
pub fn take_regenerate_trigger(&self) -> bool {
self.0.trigger_regenerate.swap(false, Ordering::SeqCst)
}
pub fn trigger_regenerate(&self) {
self.0.trigger_regenerate.store(true, Ordering::SeqCst)
}
pub fn request_termination(&self) {
if let Some(handle) = self.0.termination_sender.lock().ok().and_then(|mut lock| lock.take()) {
handle.terminate()
}
}
pub fn set_termination_handle<H: ImaginateTerminationHandle>(&self, handle: Box<H>) {
if let Ok(mut lock) = self.0.termination_sender.lock() {
*lock = Some(handle)
}
}
}
impl std::cmp::PartialEq for ImaginateController {
fn eq(&self, other: &Self) -> bool {
Arc::ptr_eq(&self.0, &other.0)
}
}
impl core::hash::Hash for ImaginateController {
fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
core::ptr::hash(Arc::as_ptr(&self.0), state)
}
}
#[derive(Default, Debug, Clone, PartialEq, DynAny, specta::Type)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ImaginateStatus {
#[default]
Idle,
Ready,
ReadyDone,
Beginning,
Uploading(f64),
Generating,
Uploading,
Generating(f64),
Terminating,
Terminated,
Failed(String),
}
impl ImaginateStatus {
pub fn to_text(&self) -> Cow<'static, str> {
match self {
Self::Ready => Cow::Borrowed("Ready"),
Self::ReadyDone => Cow::Borrowed("Done"),
Self::Beginning => Cow::Borrowed("Beginning…"),
Self::Uploading => Cow::Borrowed("Downloading Image…"),
Self::Generating(percent) => Cow::Owned(format!("Generating {percent:.0}%")),
Self::Terminating => Cow::Owned(format!("Terminating…")),
Self::Terminated => Cow::Owned(format!("Terminated")),
Self::Failed(err) => Cow::Owned(format!("Failed: {err}")),
}
}
}
#[allow(clippy::derived_hash_with_manual_eq)]
impl core::hash::Hash for ImaginateStatus {
fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
core::mem::discriminant(self).hash(state);
match self {
Self::Idle => 0.hash(state),
Self::Beginning => 1.hash(state),
Self::Uploading(f) => {
2.hash(state);
f.to_bits().hash(state);
}
Self::Generating => 3.hash(state),
Self::Terminating => 4.hash(state),
Self::Terminated => 5.hash(state),
Self::Ready | Self::ReadyDone | Self::Beginning | Self::Uploading | Self::Terminating | Self::Terminated => (),
Self::Generating(f) => f.to_bits().hash(state),
Self::Failed(err) => err.hash(state),
}
}
}
#[derive(Debug, Clone, PartialEq, specta::Type)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ImaginateBaseImage {
pub mime: String,
#[cfg_attr(feature = "serde", serde(rename = "imageData"))]
pub image_data: Vec<u8>,
pub size: DVec2,
#[derive(PartialEq, Eq, Clone, Default, Debug)]
pub enum ImaginateServerStatus {
#[default]
Unknown,
Checking,
Connected,
Failed(String),
Unavailable,
}
#[derive(Debug, Clone, PartialEq, specta::Type)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ImaginateMaskImage {
pub svg: String,
pub size: DVec2,
impl ImaginateServerStatus {
pub fn to_text(&self) -> Cow<'static, str> {
match self {
Self::Unknown | Self::Checking => Cow::Borrowed("Checking..."),
Self::Connected => Cow::Borrowed("Connected"),
Self::Failed(err) => Cow::Owned(err.clone()),
Self::Unavailable => Cow::Borrowed("Unavailable"),
}
}
}
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
@ -180,24 +285,26 @@ impl std::fmt::Display for ImaginateSamplingMethod {
}
}
#[derive(Debug, Clone, PartialEq, specta::Type)]
#[derive(Clone, Debug, PartialEq, Hash, specta::Type)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ImaginateGenerationParameters {
pub seed: u64,
pub samples: u32,
/// Use `ImaginateSamplingMethod::api_value()` to generate this string
#[cfg_attr(feature = "serde", serde(rename = "samplingMethod"))]
pub sampling_method: String,
#[cfg_attr(feature = "serde", serde(rename = "denoisingStrength"))]
pub image_creativity: Option<f64>,
#[cfg_attr(feature = "serde", serde(rename = "cfgScale"))]
pub text_guidance: f64,
#[cfg_attr(feature = "serde", serde(rename = "prompt"))]
pub text_prompt: String,
#[cfg_attr(feature = "serde", serde(rename = "negativePrompt"))]
pub negative_prompt: String,
pub resolution: (u32, u32),
#[cfg_attr(feature = "serde", serde(rename = "restoreFaces"))]
pub restore_faces: bool,
pub tiling: bool,
pub struct ImaginatePreferences {
pub host_name: String,
}
impl graphene_core::application_io::GetImaginatePreferences for ImaginatePreferences {
fn get_host_name(&self) -> &str {
&self.host_name
}
}
impl Default for ImaginatePreferences {
fn default() -> Self {
Self {
host_name: "http://localhost:7860/".into(),
}
}
}
unsafe impl dyn_any::StaticType for ImaginatePreferences {
type Static = ImaginatePreferences;
}

View file

@ -9,12 +9,18 @@ license = "MIT OR Apache-2.0"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[features]
default = ["wasm"]
gpu = ["graphene-core/gpu", "gpu-compiler-bin-wrapper", "compilation-client", "gpu-executor"]
default = ["wasm", "imaginate"]
gpu = [
"graphene-core/gpu",
"gpu-compiler-bin-wrapper",
"compilation-client",
"gpu-executor",
]
vulkan = ["gpu", "vulkan-executor"]
wgpu = ["gpu", "wgpu-executor"]
quantization = ["autoquant"]
wasm = ["wasm-bindgen", "web-sys", "js-sys"]
imaginate = ["image/png", "base64", "js-sys", "web-sys", "wasm-bindgen-futures"]
[dependencies]
@ -37,6 +43,7 @@ compilation-client = { path = "../compilation-client", optional = true }
bytemuck = { version = "1.8" }
tempfile = "3"
image = { version = "*", default-features = false }
base64 = { version = "0.21", optional = true }
dyn-clone = "1.0"
log = "0.4"
@ -48,12 +55,13 @@ glam = { version = "0.22", features = ["serde"] }
node-macro = { path = "../node-macro" }
xxhash-rust = { workspace = true }
serde_json = "1.0.96"
reqwest = { version = "0.11.17", features = ["rustls", "rustls-tls"] }
reqwest = { version = "0.11.18", features = ["rustls", "rustls-tls", "json"] }
futures = "0.3.28"
wasm-bindgen = { version = "0.2.84", optional = true }
js-sys = { version = "0.3.55", optional = true }
js-sys = { version = "0.3.63", optional = true }
wgpu-types = "0.16.0"
wgpu = "0.16.1"
wasm-bindgen-futures = { version = "0.4.36", optional = true }
[dependencies.serde]
version = "1.0"
@ -62,7 +70,7 @@ features = ["derive"]
[dependencies.web-sys]
version = "0.3.4"
version = "0.3.63"
optional = true
features = [
"Window",

View file

@ -76,6 +76,7 @@ impl<_I, _O, S0> DynAnyRefNode<_I, _O, S0> {
Self { node, _i: core::marker::PhantomData }
}
}
pub struct DynAnyInRefNode<I, O, Node> {
node: Node,
_i: PhantomData<(I, O)>,
@ -115,6 +116,10 @@ where
fn reset(&self) {
self.node.reset();
}
fn serialize(&self) -> Option<std::sync::Arc<dyn core::any::Any>> {
self.node.serialize()
}
}
impl<N> FutureWrapperNode<N> {

View file

@ -1,5 +1,3 @@
use std::future::Future;
use crate::Node;
pub struct GetNode;
@ -17,16 +15,3 @@ pub struct PostNode<Body> {
async fn post_node(url: String, body: String) -> reqwest::Response {
reqwest::Client::new().post(url).body(body).send().await.unwrap()
}
#[derive(Clone, Copy, Debug)]
pub struct EvalSyncNode {}
#[node_macro::node_fn(EvalSyncNode)]
fn eval_sync<F: Future + 'input>(future: F) -> F::Output {
let future = futures::future::maybe_done(future);
futures::pin_mut!(future);
match future.as_mut().take_output() {
Some(value) => value,
_ => panic!("Node construction future returned pending"),
}
}

View file

@ -0,0 +1,517 @@
use crate::wasm_application_io::WasmEditorApi;
use core::any::TypeId;
use core::future::Future;
use futures::{future::Either, TryFutureExt};
use glam::DVec2;
use graph_craft::imaginate_input::{ImaginateController, ImaginateMaskStartingFill, ImaginatePreferences, ImaginateSamplingMethod, ImaginateServerStatus, ImaginateStatus, ImaginateTerminationHandle};
use graphene_core::application_io::NodeGraphUpdateMessage;
use graphene_core::raster::{Color, Image, Luma, Pixel};
use image::{DynamicImage, ImageBuffer, ImageOutputFormat};
use reqwest::Url;
const PROGRESS_EVERY_N_STEPS: u32 = 5;
const SDAPI_TEXT_TO_IMAGE: &str = "sdapi/v1/txt2img";
const SDAPI_IMAGE_TO_IMAGE: &str = "sdapi/v1/img2img";
const SDAPI_PROGRESS: &str = "sdapi/v1/progress?skip_current_image=true";
const SDAPI_TERMINATE: &str = "sdapi/v1/interrupt";
fn new_client() -> Result<reqwest::Client, Error> {
reqwest::ClientBuilder::new().build().map_err(Error::ClientBuild)
}
fn parse_url(url: &str) -> Result<Url, Error> {
url.try_into().map_err(|err| Error::UrlParse { text: url.into(), err })
}
fn join_url(base_url: &Url, path: &str) -> Result<Url, Error> {
base_url.join(path).map_err(|err| Error::UrlParse { text: base_url.to_string(), err })
}
fn new_get_request<U: reqwest::IntoUrl>(client: &reqwest::Client, url: U) -> Result<reqwest::Request, Error> {
client.get(url).header("Accept", "*/*").build().map_err(Error::RequestBuild)
}
pub struct ImaginatePersistentData {
pending_server_check: Option<futures::channel::oneshot::Receiver<reqwest::Result<reqwest::Response>>>,
host_name: Url,
client: Option<reqwest::Client>,
server_status: ImaginateServerStatus,
}
impl core::fmt::Debug for ImaginatePersistentData {
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
f.debug_struct(core::any::type_name::<Self>())
.field("pending_server_check", &self.pending_server_check.is_some())
.field("host_name", &self.host_name)
.field("status", &self.server_status)
.finish()
}
}
impl Default for ImaginatePersistentData {
fn default() -> Self {
let mut status = ImaginateServerStatus::default();
let client = new_client().map_err(|err| status = ImaginateServerStatus::Failed(err.to_string())).ok();
let ImaginatePreferences { host_name } = Default::default();
Self {
pending_server_check: None,
host_name: parse_url(&host_name).unwrap(),
client,
server_status: status,
}
}
}
impl ImaginatePersistentData {
pub fn set_host_name(&mut self, name: &str) {
match parse_url(name) {
Ok(url) => self.host_name = url,
Err(err) => self.server_status = ImaginateServerStatus::Failed(err.to_string()),
}
}
fn initiate_server_check_maybe_fail(&mut self) -> Result<Option<core::pin::Pin<Box<dyn Future<Output = ()> + 'static>>>, Error> {
use futures::future::FutureExt;
let Some(client) = &self.client else { return Ok(None); };
if self.pending_server_check.is_some() {
return Ok(None);
}
self.server_status = ImaginateServerStatus::Checking;
let url = join_url(&self.host_name, SDAPI_PROGRESS)?;
let request = new_get_request(client, url)?;
let (send, recv) = futures::channel::oneshot::channel();
let response_future = client.execute(request).map(move |r| {
let _ = send.send(r);
});
self.pending_server_check = Some(recv);
Ok(Some(Box::pin(response_future)))
}
pub fn initiate_server_check(&mut self) -> Option<core::pin::Pin<Box<dyn Future<Output = ()> + 'static>>> {
match self.initiate_server_check_maybe_fail() {
Ok(f) => f,
Err(err) => {
self.server_status = ImaginateServerStatus::Failed(err.to_string());
None
}
}
}
pub fn poll_server_check(&mut self) {
if let Some(mut check) = self.pending_server_check.take() {
self.server_status = match check.try_recv().map(|r| r.map(|r| r.and_then(reqwest::Response::error_for_status))) {
Ok(Some(Ok(_response))) => ImaginateServerStatus::Connected,
Ok(Some(Err(_))) | Err(_) => ImaginateServerStatus::Unavailable,
Ok(None) => {
self.pending_server_check = Some(check);
ImaginateServerStatus::Checking
}
}
}
}
pub fn server_status(&self) -> &ImaginateServerStatus {
&self.server_status
}
pub fn is_checking(&self) -> bool {
matches!(self.server_status, ImaginateServerStatus::Checking)
}
}
#[derive(Debug)]
struct ImaginateFutureAbortHandle(futures::future::AbortHandle);
impl ImaginateTerminationHandle for ImaginateFutureAbortHandle {
fn terminate(&self) {
self.0.abort()
}
}
#[derive(Debug)]
enum Error {
UrlParse { text: String, err: <&'static str as TryInto<Url>>::Error },
ClientBuild(reqwest::Error),
RequestBuild(reqwest::Error),
Request(reqwest::Error),
ResponseFormat(reqwest::Error),
NoImage,
Base64Decode(base64::DecodeError),
ImageDecode(image::error::ImageError),
ImageEncode(image::error::ImageError),
UnsupportedPixelType(&'static str),
InconsistentImageSize,
Terminated,
TerminationFailed(reqwest::Error),
}
impl core::fmt::Display for Error {
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
match self {
Self::UrlParse { text, err } => write!(f, "invalid url '{text}' ({err})"),
Self::ClientBuild(err) => write!(f, "failed to create a reqwest client ({err})"),
Self::RequestBuild(err) => write!(f, "failed to create a reqwest request ({err})"),
Self::Request(err) => write!(f, "request failed ({err})"),
Self::ResponseFormat(err) => write!(f, "got an invalid API response ({err})"),
Self::NoImage => write!(f, "got an empty API response"),
Self::Base64Decode(err) => write!(f, "failed to decode base64 encoded image ({err})"),
Self::ImageDecode(err) => write!(f, "failed to decode png image ({err})"),
Self::ImageEncode(err) => write!(f, "failed to encode png image ({err})"),
Self::UnsupportedPixelType(ty) => write!(f, "pixel type `{ty}` not supported for imaginate images"),
Self::InconsistentImageSize => write!(f, "image width and height do not match the image byte size"),
Self::Terminated => write!(f, "imaginate request was terminated by the user"),
Self::TerminationFailed(err) => write!(f, "termination failed ({err})"),
}
}
}
impl std::error::Error for Error {}
#[derive(Default, Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize))]
struct ImageResponse {
images: Vec<String>,
}
#[derive(Default, Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize))]
struct ProgressResponse {
progress: f64,
}
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
struct ImaginateTextToImageRequestOverrideSettings {
show_progress_every_n_steps: u32,
}
impl Default for ImaginateTextToImageRequestOverrideSettings {
fn default() -> Self {
Self {
show_progress_every_n_steps: PROGRESS_EVERY_N_STEPS,
}
}
}
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
struct ImaginateImageToImageRequestOverrideSettings {
show_progress_every_n_steps: u32,
img2img_fix_steps: bool,
}
impl Default for ImaginateImageToImageRequestOverrideSettings {
fn default() -> Self {
Self {
show_progress_every_n_steps: PROGRESS_EVERY_N_STEPS,
img2img_fix_steps: true,
}
}
}
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
struct ImaginateTextToImageRequest<'a> {
#[serde(flatten)]
common: ImaginateCommonImageRequest<'a>,
override_settings: ImaginateTextToImageRequestOverrideSettings,
}
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
struct ImaginateMask {
mask: String,
mask_blur: String,
inpainting_fill: u32,
inpaint_full_res: bool,
inpainting_mask_invert: u32,
}
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
struct ImaginateImageToImageRequest<'a> {
#[serde(flatten)]
common: ImaginateCommonImageRequest<'a>,
override_settings: ImaginateImageToImageRequestOverrideSettings,
init_images: Vec<String>,
denoising_strength: f64,
#[serde(flatten)]
mask: Option<ImaginateMask>,
}
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
struct ImaginateCommonImageRequest<'a> {
prompt: String,
seed: f64,
steps: u32,
cfg_scale: f64,
width: f64,
height: f64,
restore_faces: bool,
tiling: bool,
negative_prompt: String,
sampler_index: &'a str,
}
#[cfg(feature = "imaginate")]
pub async fn imaginate<'a, P: Pixel>(
image: Image<P>,
editor_api: impl Future<Output = WasmEditorApi<'a>>,
controller: ImaginateController,
seed: impl Future<Output = f64>,
res: impl Future<Output = Option<DVec2>>,
samples: impl Future<Output = u32>,
sampling_method: impl Future<Output = ImaginateSamplingMethod>,
prompt_guidance: impl Future<Output = f64>,
prompt: impl Future<Output = String>,
negative_prompt: impl Future<Output = String>,
adapt_input_image: impl Future<Output = bool>,
image_creativity: impl Future<Output = f64>,
masking_layer: impl Future<Output = Option<Vec<u64>>>,
inpaint: impl Future<Output = bool>,
mask_blur: impl Future<Output = f64>,
mask_starting_fill: impl Future<Output = ImaginateMaskStartingFill>,
improve_faces: impl Future<Output = bool>,
tiling: impl Future<Output = bool>,
) -> Image<P> {
let WasmEditorApi {
node_graph_message_sender,
imaginate_preferences,
..
} = editor_api.await;
let set_progress = |progress: ImaginateStatus| {
controller.set_status(progress);
node_graph_message_sender.send(NodeGraphUpdateMessage::ImaginateStatusUpdate);
};
let host_name = imaginate_preferences.get_host_name();
imaginate_maybe_fail(
image,
host_name,
set_progress,
&controller,
seed,
res,
samples,
sampling_method,
prompt_guidance,
prompt,
negative_prompt,
adapt_input_image,
image_creativity,
masking_layer,
inpaint,
mask_blur,
mask_starting_fill,
improve_faces,
tiling,
)
.await
.unwrap_or_else(|err| {
match err {
Error::Terminated => {
set_progress(ImaginateStatus::Terminated);
}
err => {
error!("{err}");
set_progress(ImaginateStatus::Failed(err.to_string()));
}
};
Image::empty()
})
}
#[cfg(feature = "imaginate")]
async fn imaginate_maybe_fail<'a, P: Pixel, F: Fn(ImaginateStatus)>(
image: Image<P>,
host_name: &str,
set_progress: F,
controller: &ImaginateController,
seed: impl Future<Output = f64>,
res: impl Future<Output = Option<DVec2>>,
samples: impl Future<Output = u32>,
sampling_method: impl Future<Output = ImaginateSamplingMethod>,
prompt_guidance: impl Future<Output = f64>,
prompt: impl Future<Output = String>,
negative_prompt: impl Future<Output = String>,
adapt_input_image: impl Future<Output = bool>,
image_creativity: impl Future<Output = f64>,
_masking_layer: impl Future<Output = Option<Vec<u64>>>,
_inpaint: impl Future<Output = bool>,
_mask_blur: impl Future<Output = f64>,
_mask_starting_fill: impl Future<Output = ImaginateMaskStartingFill>,
improve_faces: impl Future<Output = bool>,
tiling: impl Future<Output = bool>,
) -> Result<Image<P>, Error> {
set_progress(ImaginateStatus::Beginning);
let base_url: Url = parse_url(host_name)?;
let client = new_client()?;
let sampler_index = sampling_method.await;
let sampler_index = sampler_index.api_value();
let res = res.await.unwrap_or_else(|| {
let (width, height) = pick_safe_imaginate_resolution((image.width as _, image.height as _));
DVec2::new(width as _, height as _)
});
let common_request_data = ImaginateCommonImageRequest {
prompt: prompt.await,
seed: seed.await,
steps: samples.await,
cfg_scale: prompt_guidance.await,
width: res.x,
height: res.y,
restore_faces: improve_faces.await,
tiling: tiling.await,
negative_prompt: negative_prompt.await,
sampler_index,
};
let request_builder = if adapt_input_image.await {
let base64_data = image_to_base64(image)?;
let request_data = ImaginateImageToImageRequest {
common: common_request_data,
override_settings: Default::default(),
init_images: vec![base64_data],
denoising_strength: image_creativity.await * 0.01,
mask: None,
};
let url = join_url(&base_url, SDAPI_IMAGE_TO_IMAGE)?;
client.post(url).json(&request_data)
} else {
let request_data = ImaginateTextToImageRequest {
common: common_request_data,
override_settings: Default::default(),
};
let url = join_url(&base_url, SDAPI_TEXT_TO_IMAGE)?;
client.post(url).json(&request_data)
};
let request = request_builder.header("Accept", "*/*").build().map_err(Error::RequestBuild)?;
let (response_future, abort_handle) = futures::future::abortable(client.execute(request));
controller.set_termination_handle(Box::new(ImaginateFutureAbortHandle(abort_handle)));
let progress_url = join_url(&base_url, SDAPI_PROGRESS)?;
futures::pin_mut!(response_future);
let response = loop {
let progress_request = new_get_request(&client, progress_url.clone())?;
let progress_response_future = client.execute(progress_request).and_then(|response| response.json());
futures::pin_mut!(progress_response_future);
response_future = match futures::future::select(response_future, progress_response_future).await {
Either::Left((response, _)) => break response,
Either::Right((progress, response_future)) => {
if let Ok(ProgressResponse { progress }) = progress {
set_progress(ImaginateStatus::Generating(progress * 100.));
}
response_future
}
};
};
let response = match response {
Ok(response) => response.and_then(reqwest::Response::error_for_status).map_err(Error::Request)?,
Err(_aborted) => {
set_progress(ImaginateStatus::Terminating);
let url = join_url(&base_url, SDAPI_TERMINATE)?;
let request = client.post(url).build().map_err(Error::RequestBuild)?;
// The user probably doesn't really care if the server side was really aborted or if there was an network error.
// So we fool them that the request was terminated if the termination request in reality failed.
let _ = client.execute(request).await.and_then(reqwest::Response::error_for_status).map_err(Error::TerminationFailed)?;
return Err(Error::Terminated);
}
};
set_progress(ImaginateStatus::Uploading);
let ImageResponse { images } = response.json().await.map_err(Error::ResponseFormat)?;
let result = images.into_iter().next().ok_or(Error::NoImage).and_then(base64_to_image)?;
set_progress(ImaginateStatus::ReadyDone);
Ok(result)
}
fn image_to_base64<P: Pixel>(image: Image<P>) -> Result<String, Error> {
use base64::prelude::*;
let Image { width, height, data } = image;
fn cast_with_f32<S: Pixel, D: image::Pixel<Subpixel = f32>>(data: Vec<S>, width: u32, height: u32) -> Result<DynamicImage, Error>
where
DynamicImage: From<ImageBuffer<D, Vec<f32>>>,
{
ImageBuffer::<D, Vec<f32>>::from_raw(width, height, bytemuck::cast_vec(data))
.ok_or(Error::InconsistentImageSize)
.map(Into::into)
}
let image: DynamicImage = match TypeId::of::<P>() {
id if id == TypeId::of::<Color>() => cast_with_f32::<_, image::Rgba<f32>>(data, width, height)?
// we need to do this cast, because png does not support rgba32f
.to_rgba16().into(),
id if id == TypeId::of::<Luma>() => cast_with_f32::<_, image::Luma<f32>>(data, width, height)?
// we need to do this cast, because png does not support luma32f
.to_luma16().into(),
_ => return Err(Error::UnsupportedPixelType(core::any::type_name::<P>())),
};
let mut png_data = std::io::Cursor::new(vec![]);
image.write_to(&mut png_data, ImageOutputFormat::Png).map_err(Error::ImageEncode)?;
Ok(BASE64_STANDARD.encode(png_data.into_inner()))
}
fn base64_to_image<D: AsRef<[u8]>, P: Pixel>(base64_data: D) -> Result<Image<P>, Error> {
use base64::prelude::*;
let png_data = BASE64_STANDARD.decode(base64_data).map_err(Error::Base64Decode)?;
let dyn_image = image::load_from_memory_with_format(&png_data, image::ImageFormat::Png).map_err(Error::ImageDecode)?;
let (width, height) = (dyn_image.width(), dyn_image.height());
let result_data: Vec<P> = match TypeId::of::<P>() {
id if id == TypeId::of::<Color>() => bytemuck::cast_vec(dyn_image.into_rgba32f().into_raw()),
id if id == TypeId::of::<Luma>() => bytemuck::cast_vec(dyn_image.to_luma32f().into_raw()),
_ => return Err(Error::UnsupportedPixelType(core::any::type_name::<P>())),
};
Ok(Image { data: result_data, width, height })
}
pub fn pick_safe_imaginate_resolution((width, height): (f64, f64)) -> (u64, u64) {
const MAX_RESOLUTION: u64 = 1000 * 1000;
// this is the maximum width/height that can be obtained
const MAX_DIMENSION: u64 = (MAX_RESOLUTION / 64) & !63;
// round the resolution to the nearest multiple of 64
let [width, height] = [width, height].map(|c| (c.round().clamp(0., MAX_DIMENSION as _) as u64 + 32).max(64) & !63);
let resolution = width * height;
if resolution > MAX_RESOLUTION {
// scale down the image, so it is smaller than MAX_RESOLUTION
let scale = (MAX_RESOLUTION as f64 / resolution as f64).sqrt();
let [width, height] = [width, height].map(|c| c as f64 * scale);
if width < 64.0 {
// the image is extremely wide
(64, MAX_DIMENSION)
} else if height < 64.0 {
// the image is extremely high
(MAX_DIMENSION, 64)
} else {
// round down to a multiple of 64, so that the resolution still is smaller than MAX_RESOLUTION
let [width, height] = [width, height].map(|c| c as u64 & !63);
(width, height)
}
} else {
(width, height)
}
}

View file

@ -25,3 +25,5 @@ pub mod brush;
#[cfg(feature = "wasm")]
pub mod wasm_application_io;
pub mod imaginate;

View file

@ -1,8 +1,11 @@
use dyn_any::{DynAny, StaticType};
use glam::{DAffine2, DVec2};
use graph_craft::imaginate_input::{ImaginateController, ImaginateMaskStartingFill, ImaginateSamplingMethod};
use graph_craft::proto::DynFuture;
use graphene_core::raster::{Alpha, BlendMode, BlendNode, Image, ImageFrame, Linear, LinearChannel, Luminance, Pixel, RGBMut, Raster, RasterMut, RedGreenBlue, Sample};
use graphene_core::transform::Transform;
use crate::wasm_application_io::WasmEditorApi;
use graphene_core::raster::bbox::{AxisAlignedBbox, Bbox};
use graphene_core::value::CopiedNode;
use graphene_core::{Color, Node};
@ -414,19 +417,74 @@ fn empty_image<_P: Pixel>(transform: DAffine2, color: _P) -> ImageFrame<_P> {
ImageFrame { image, transform }
}
#[derive(Debug, Clone, Copy)]
pub struct ImaginateNode<P, E> {
cached: E,
_p: PhantomData<P>,
macro_rules! generate_imaginate_node {
($($val:ident: $t:ident: $o:ty,)*) => {
pub struct ImaginateNode<P: Pixel, E, C, $($t,)*> {
editor_api: E,
controller: C,
$($val: $t,)*
cache: std::sync::Arc<std::sync::Mutex<Image<P>>>,
}
impl<'e, P: Pixel, E, C, $($t,)*> ImaginateNode<P, E, C, $($t,)*>
where $($t: for<'any_input> Node<'any_input, (), Output = DynFuture<'any_input, $o>>,)*
E: for<'any_input> Node<'any_input, (), Output = DynFuture<'any_input, WasmEditorApi<'e>>>,
C: for<'any_input> Node<'any_input, (), Output = DynFuture<'any_input, ImaginateController>>,
{
pub fn new(editor_api: E, controller: C, $($val: $t,)* cache: std::sync::Arc<std::sync::Mutex<Image<P>>>) -> Self {
Self { editor_api, controller, $($val,)* cache }
}
}
impl<'i, 'e: 'i, P: Pixel + 'i, E: 'i, C: 'i, $($t: 'i,)*> Node<'i, ImageFrame<P>> for ImaginateNode<P, E, C, $($t,)*>
where $($t: for<'any_input> Node<'any_input, (), Output = DynFuture<'any_input, $o>>,)*
E: for<'any_input> Node<'any_input, (), Output = DynFuture<'any_input, WasmEditorApi<'e>>>,
C: for<'any_input> Node<'any_input, (), Output = DynFuture<'any_input, ImaginateController>>,
{
type Output = DynFuture<'i, ImageFrame<P>>;
fn eval(&'i self, frame: ImageFrame<P>) -> Self::Output {
let controller = self.controller.eval(());
$(let $val = self.$val.eval(());)*
Box::pin(async move {
let controller: std::pin::Pin<Box<dyn std::future::Future<Output = ImaginateController>>> = controller;
let controller: ImaginateController = controller.await;
if controller.take_regenerate_trigger() {
let editor_api = self.editor_api.eval(());
let image = super::imaginate::imaginate(frame.image, editor_api, controller, $($val,)*).await;
self.cache.lock().unwrap().clone_from(&image);
return ImageFrame {
image,
..frame
}
}
let image = self.cache.lock().unwrap().clone();
ImageFrame {
image,
..frame
}
})
}
}
}
}
#[node_macro::node_fn(ImaginateNode<_P>)]
fn imaginate<_P: Pixel>(image_frame: ImageFrame<_P>, cached: Option<std::sync::Arc<graphene_core::raster::Image<_P>>>) -> ImageFrame<_P> {
let cached_image = cached.map(|mut x| std::sync::Arc::make_mut(&mut x).clone()).unwrap_or(image_frame.image);
ImageFrame {
image: cached_image,
transform: image_frame.transform,
}
generate_imaginate_node! {
seed: Seed: f64,
res: Res: Option<DVec2>,
samples: Samples: u32,
sampling_method: SamplingMethod: ImaginateSamplingMethod,
prompt_guidance: PromptGuidance: f64,
prompt: Prompt: String,
negative_prompt: NegativePrompt: String,
adapt_input_image: AdaptInputImage: bool,
image_creativity: ImageCreativity: f64,
masking_layer: MaskingLayer: Option<Vec<u64>>,
inpaint: Inpaint: bool,
mask_blur: MaskBlur: f64,
mask_starting_fill: MaskStartingFill: ImaginateMaskStartingFill,
improve_faces: ImproveFaces: bool,
tiling: Tiling: bool,
}
#[derive(Debug, Clone, Copy)]

View file

@ -3,7 +3,6 @@ pub mod node_registry;
#[cfg(test)]
mod tests {
use graph_craft::document::value::TaggedValue;
use graphene_core::*;
use std::borrow::Cow;

View file

@ -1,3 +1,4 @@
use graph_craft::imaginate_input::{ImaginateCache, ImaginateController, ImaginateMaskStartingFill, ImaginateSamplingMethod};
use graph_craft::proto::{NodeConstructor, TypeErasedBox};
use graphene_core::ops::IdNode;
use graphene_core::quantization::QuantizationChannels;
@ -444,18 +445,59 @@ fn node_registry() -> HashMap<NodeIdentifier, HashMap<NodeIOTypes, NodeConstruct
params: [WasmSurfaceHandleFrame]
),
async_node!(graphene_core::memo::EndLetNode<_>, input: WasmEditorApi, output: SurfaceFrame, params: [SurfaceFrame]),
vec![(
NodeIdentifier::new("graphene_core::memo::RefNode<_, _>"),
|args| {
Box::pin(async move {
let node: DowncastBothNode<Option<WasmEditorApi>, WasmEditorApi> = graphene_std::any::DowncastBothNode::new(args[0].clone());
let node = <graphene_core::memo::RefNode<_, _>>::new(node);
let any: DynAnyNode<(), _, _> = graphene_std::any::DynAnyNode::new(graphene_core::value::ValueNode::new(node));
any.into_type_erased()
})
},
NodeIOTypes::new(concrete!(()), concrete!(WasmEditorApi), vec![fn_type!(Option<WasmEditorApi>, WasmEditorApi)]),
)],
vec![
(
NodeIdentifier::new("graphene_core::memo::RefNode<_, _>"),
|args| {
Box::pin(async move {
let node: DowncastBothNode<Option<WasmEditorApi>, WasmEditorApi> = graphene_std::any::DowncastBothNode::new(args[0].clone());
let node = <graphene_core::memo::RefNode<_, _>>::new(node);
let any: DynAnyNode<(), _, _> = graphene_std::any::DynAnyNode::new(graphene_core::value::ValueNode::new(node));
any.into_type_erased()
})
},
NodeIOTypes::new(concrete!(()), concrete!(WasmEditorApi), vec![fn_type!(Option<WasmEditorApi>, WasmEditorApi)]),
),
(
NodeIdentifier::new("graphene_std::raster::ImaginateNode<_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _>"),
|args: Vec<Arc<graph_craft::proto::NodeContainer>>| {
Box::pin(async move {
use graphene_std::raster::ImaginateNode;
let cache: ImaginateCache = graphene_std::any::input_node(args.last().unwrap().clone()).eval(()).await;
macro_rules! instanciate_imaginate_node {
($($i:expr,)*) => { ImaginateNode::new($(graphene_std::any::input_node(args[$i].clone()),)* cache.into_inner()) };
}
let node: ImaginateNode<Color, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _> = instanciate_imaginate_node!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,);
let any = graphene_std::any::DynAnyNode::new(ValueNode::new(node));
any.into_type_erased()
})
},
NodeIOTypes::new(
concrete!(ImageFrame<Color>),
concrete!(ImageFrame<Color>),
vec![
fn_type!(WasmEditorApi),
fn_type!(ImaginateController),
fn_type!(f64),
fn_type!(Option<DVec2>),
fn_type!(u32),
fn_type!(ImaginateSamplingMethod),
fn_type!(f64),
fn_type!(String),
fn_type!(String),
fn_type!(bool),
fn_type!(f64),
fn_type!(Option<Vec<u64>>),
fn_type!(bool),
fn_type!(f64),
fn_type!(ImaginateMaskStartingFill),
fn_type!(bool),
fn_type!(bool),
fn_type!(ImaginateCache),
],
),
),
],
async_node!(graphene_core::memo::MemoNode<_, _>, input: (), output: Image<Color>, params: [Image<Color>]),
async_node!(graphene_core::memo::MemoNode<_, _>, input: (), output: ImageFrame<Color>, params: [ImageFrame<Color>]),
async_node!(graphene_core::memo::MemoNode<_, _>, input: (), output: QuantizationChannels, params: [QuantizationChannels]),