mirror of
https://github.com/GraphiteEditor/Graphite.git
synced 2025-08-04 13:30:48 +00:00
Fix Imaginate by porting its JS roundtrip code to graph-based async execution in Rust (#1250)
* Create asynchronous rust imaginate node * Make a first imaginate request via rust * Implement parsing of imaginate API result image * Stop refresh timer from affecting imaginate progress requests * Add cargo-about clarification for rustls-webpki * Delete imaginate.ts and all uses of its functions * Add imaginate img2img feature * Fix imaginate random seed button * Fix imaginate ui inferring non-custom resolutions * Fix the imaginate progress indicator * Remove ImaginatePreferences from being compiled into node graph * Regenerate imaginate only when hitting button * Add ability to terminate imaginate requests * Add imaginate server check feature * Do not compile wasm_bindgen bindings in graphite_editor for tests * Address some review suggestions - move wasm futures dependency in editor to the future-executor crate - guard wasm-bindgen in editor behind a `wasm` feature flag - dont make seed number input a slider - remove poll_server_check from process_message function beginning - guard wasm related code behind `cfg(target_arch = "wasm32")` instead of `cfg(test)` - Call the imaginate idle states "Ready" and "Done" instead of "Nothing to do" - Call the imaginate uploading state "Uploading Image" instead of "Uploading Input Image" - Remove the EvalSyncNode * Fix imaginate host name being restored between graphite instances also change the progress status texts a bit. --------- Co-authored-by: Keavon Chambers <keavon@keavon.com>
This commit is contained in:
parent
a1c70c4d90
commit
f76b850b9c
35 changed files with 1500 additions and 1326 deletions
710
Cargo.lock
generated
710
Cargo.lock
generated
File diff suppressed because it is too large
Load diff
|
@ -21,7 +21,12 @@ workarounds = ["ring"]
|
|||
# is the ISC license but test code within the repo is BSD-3-Clause, but is not compiled into the crate when we use it
|
||||
[webpki.clarify]
|
||||
license = "ISC"
|
||||
|
||||
[[webpki.clarify.files]]
|
||||
path = 'LICENSE'
|
||||
path = "LICENSE"
|
||||
checksum = "5b698ca13897be3afdb7174256fa1574f8c6892b8bea1a66dd6469d3fe27885a"
|
||||
|
||||
[rustls-webpki.clarify]
|
||||
license = "ISC"
|
||||
[[rustls-webpki.clarify.files]]
|
||||
path = "LICENSE"
|
||||
checksum = "5b698ca13897be3afdb7174256fa1574f8c6892b8bea1a66dd6469d3fe27885a"
|
||||
|
|
|
@ -983,23 +983,5 @@ pub fn pick_layer_safe_imaginate_resolution(layer: &Layer, render_data: &RenderD
|
|||
let layer_bounds = layer.bounding_transform(render_data);
|
||||
let layer_bounds_size = (layer_bounds.transform_vector2((1., 0.).into()).length(), layer_bounds.transform_vector2((0., 1.).into()).length());
|
||||
|
||||
pick_safe_imaginate_resolution(layer_bounds_size)
|
||||
}
|
||||
|
||||
pub fn pick_safe_imaginate_resolution((width, height): (f64, f64)) -> (u64, u64) {
|
||||
const MAX_RESOLUTION: u64 = 1000 * 1000;
|
||||
|
||||
let mut scale_factor = 1.;
|
||||
|
||||
let round_to_increment = |size: f64| (size / 64.).round() as u64 * 64;
|
||||
|
||||
loop {
|
||||
let possible_solution = (round_to_increment(width * scale_factor), round_to_increment(height * scale_factor));
|
||||
|
||||
if possible_solution.0 * possible_solution.1 <= MAX_RESOLUTION {
|
||||
return possible_solution;
|
||||
}
|
||||
|
||||
scale_factor -= 0.1;
|
||||
}
|
||||
graphene_std::imaginate::pick_safe_imaginate_resolution(layer_bounds_size)
|
||||
}
|
||||
|
|
|
@ -11,11 +11,13 @@ repository = "https://github.com/GraphiteEditor/Graphite"
|
|||
license = "Apache-2.0"
|
||||
|
||||
[features]
|
||||
default = ["wasm"]
|
||||
gpu = ["interpreted-executor/gpu", "graphene-std/gpu", "graphene-core/gpu", "wgpu-executor", "gpu-executor"]
|
||||
quantization = [
|
||||
"graphene-std/quantization",
|
||||
"interpreted-executor/quantization",
|
||||
]
|
||||
wasm = ["wasm-bindgen", "future-executor", "graphene-std/wasm"]
|
||||
|
||||
[dependencies]
|
||||
log = "0.4"
|
||||
|
@ -45,9 +47,12 @@ gpu-executor = { path = "../node-graph/gpu-executor", optional = true }
|
|||
interpreted-executor = { path = "../node-graph/interpreted-executor" }
|
||||
dyn-any = { path = "../libraries/dyn-any" }
|
||||
graphene-core = { path = "../node-graph/gcore" }
|
||||
graphene-std = { path = "../node-graph/gstd", features = ["wasm"] }
|
||||
graphene-std = { path = "../node-graph/gstd" }
|
||||
future-executor = { path = "../node-graph/future-executor", optional = true }
|
||||
num_enum = "0.6.1"
|
||||
|
||||
wasm-bindgen = { version = "0.2.86", optional = true }
|
||||
|
||||
[dependencies.document-legacy]
|
||||
path = "../document-legacy"
|
||||
package = "graphite-document-legacy"
|
||||
|
|
|
@ -9,7 +9,6 @@ use crate::messages::tool::utility_types::HintData;
|
|||
|
||||
use document_legacy::LayerId;
|
||||
use graph_craft::document::NodeId;
|
||||
use graph_craft::imaginate_input::*;
|
||||
use graphene_core::raster::color::Color;
|
||||
use graphene_core::text::Font;
|
||||
|
||||
|
@ -75,40 +74,6 @@ pub enum FrontendMessage {
|
|||
#[serde(rename = "isDefault")]
|
||||
is_default: bool,
|
||||
},
|
||||
TriggerImaginateCheckServerStatus {
|
||||
hostname: String,
|
||||
},
|
||||
TriggerImaginateGenerate {
|
||||
parameters: Box<ImaginateGenerationParameters>,
|
||||
#[serde(rename = "baseImage")]
|
||||
base_image: Option<Box<ImaginateBaseImage>>,
|
||||
#[serde(rename = "maskImage")]
|
||||
mask_image: Option<Box<ImaginateMaskImage>>,
|
||||
#[serde(rename = "maskPaintMode")]
|
||||
mask_paint_mode: ImaginateMaskPaintMode,
|
||||
#[serde(rename = "maskBlurPx")]
|
||||
mask_blur_px: u32,
|
||||
#[serde(rename = "maskFillContent")]
|
||||
imaginate_mask_starting_fill: ImaginateMaskStartingFill,
|
||||
hostname: String,
|
||||
#[serde(rename = "refreshFrequency")]
|
||||
refresh_frequency: f64,
|
||||
#[serde(rename = "documentId")]
|
||||
document_id: u64,
|
||||
#[serde(rename = "layerPath")]
|
||||
layer_path: Vec<LayerId>,
|
||||
#[serde(rename = "nodePath")]
|
||||
node_path: Vec<NodeId>,
|
||||
},
|
||||
TriggerImaginateTerminate {
|
||||
#[serde(rename = "documentId")]
|
||||
document_id: u64,
|
||||
#[serde(rename = "layerPath")]
|
||||
layer_path: Vec<LayerId>,
|
||||
#[serde(rename = "nodePath")]
|
||||
node_path: Vec<NodeId>,
|
||||
hostname: String,
|
||||
},
|
||||
TriggerImport,
|
||||
TriggerIndexedDbRemoveDocument {
|
||||
#[serde(rename = "documentId")]
|
||||
|
|
|
@ -93,8 +93,6 @@ pub enum DocumentMessage {
|
|||
GroupSelectedLayers,
|
||||
ImaginateClear {
|
||||
layer_path: Vec<LayerId>,
|
||||
node_id: NodeId,
|
||||
cached_index: usize,
|
||||
},
|
||||
ImaginateGenerate {
|
||||
layer_path: Vec<LayerId>,
|
||||
|
@ -105,10 +103,6 @@ pub enum DocumentMessage {
|
|||
imaginate_node: Vec<NodeId>,
|
||||
then_generate: bool,
|
||||
},
|
||||
ImaginateTerminate {
|
||||
layer_path: Vec<LayerId>,
|
||||
node_path: Vec<NodeId>,
|
||||
},
|
||||
InputFrameRasterizeRegionBelowLayer {
|
||||
layer_path: Vec<LayerId>,
|
||||
},
|
||||
|
|
|
@ -465,15 +465,7 @@ impl MessageHandler<DocumentMessage, (u64, &InputPreprocessorMessageHandler, &Pe
|
|||
replacement_selected_layers: vec![new_folder_path],
|
||||
});
|
||||
}
|
||||
ImaginateClear {
|
||||
layer_path,
|
||||
node_id,
|
||||
cached_index: input_index,
|
||||
} => {
|
||||
let value = graph_craft::document::value::TaggedValue::RcImage(None);
|
||||
responses.add(NodeGraphMessage::SetInputValue { node_id, input_index, value });
|
||||
responses.add(InputFrameRasterizeRegionBelowLayer { layer_path });
|
||||
}
|
||||
ImaginateClear { layer_path } => responses.add(InputFrameRasterizeRegionBelowLayer { layer_path }),
|
||||
ImaginateGenerate { layer_path, imaginate_node } => {
|
||||
if let Some(message) = self.rasterize_region_below_layer(document_id, layer_path, preferences, persistent_data, Some(imaginate_node)) {
|
||||
responses.add(message);
|
||||
|
@ -484,12 +476,16 @@ impl MessageHandler<DocumentMessage, (u64, &InputPreprocessorMessageHandler, &Pe
|
|||
imaginate_node,
|
||||
then_generate,
|
||||
} => {
|
||||
// Generate a random seed. We only want values between -2^53 and 2^53, because integer values
|
||||
// outside of this range can get rounded in f64
|
||||
let random_bits = generate_uuid();
|
||||
let random_value = ((random_bits >> 11) as f64).copysign(f64::from_bits(random_bits & (1 << 63)));
|
||||
// Set a random seed input
|
||||
responses.add(NodeGraphMessage::SetInputValue {
|
||||
node_id: *imaginate_node.last().unwrap(),
|
||||
// Needs to match the index of the seed parameter in `pub const IMAGINATE_NODE: DocumentNodeType` in `document_node_type.rs`
|
||||
input_index: 1,
|
||||
value: graph_craft::document::value::TaggedValue::F64((generate_uuid() >> 1) as f64),
|
||||
input_index: 3,
|
||||
value: graph_craft::document::value::TaggedValue::F64(random_value),
|
||||
});
|
||||
|
||||
// Generate the image
|
||||
|
@ -497,14 +493,6 @@ impl MessageHandler<DocumentMessage, (u64, &InputPreprocessorMessageHandler, &Pe
|
|||
responses.add(DocumentMessage::ImaginateGenerate { layer_path, imaginate_node });
|
||||
}
|
||||
}
|
||||
ImaginateTerminate { layer_path, node_path } => {
|
||||
responses.add(FrontendMessage::TriggerImaginateTerminate {
|
||||
document_id,
|
||||
layer_path,
|
||||
node_path,
|
||||
hostname: preferences.imaginate_server_hostname.clone(),
|
||||
});
|
||||
}
|
||||
InputFrameRasterizeRegionBelowLayer { layer_path } => {
|
||||
if layer_path.is_empty() {
|
||||
responses.add(NodeGraphMessage::RunDocumentGraph);
|
||||
|
@ -996,7 +984,7 @@ impl DocumentMessageHandler {
|
|||
// Check if we use the "Input Frame" node.
|
||||
// TODO: Remove once rasterization is moved into a node.
|
||||
let input_frame_node_id = node_network.nodes.iter().find(|(_, node)| node.name == "Input Frame").map(|(&id, _)| id);
|
||||
let input_frame_connected_to_graph_output = input_frame_node_id.map_or(false, |target_node_id| node_network.connected_to_output(target_node_id, imaginate_node_path.is_none()));
|
||||
let input_frame_connected_to_graph_output = input_frame_node_id.map_or(false, |target_node_id| node_network.connected_to_output(target_node_id));
|
||||
|
||||
// If the Input Frame node is connected upstream, rasterize the artwork below this layer by calling into JS
|
||||
let response = if input_frame_connected_to_graph_output {
|
||||
|
|
|
@ -459,7 +459,7 @@ impl MessageHandler<NodeGraphMessage, (&mut Document, &NodeGraphExecutor, u64)>
|
|||
let input = NodeInput::node(output_node, output_node_connector_index);
|
||||
responses.add(NodeGraphMessage::SetNodeInput { node_id, input_index, input });
|
||||
|
||||
let should_rerender = network.connected_to_output(node_id, true);
|
||||
let should_rerender = network.connected_to_output(node_id);
|
||||
responses.add(NodeGraphMessage::SendGraph { should_rerender });
|
||||
}
|
||||
NodeGraphMessage::Copy => {
|
||||
|
@ -517,7 +517,7 @@ impl MessageHandler<NodeGraphMessage, (&mut Document, &NodeGraphExecutor, u64)>
|
|||
|
||||
if let Some(network) = self.get_active_network(document) {
|
||||
// Only generate node graph if one of the selected nodes is connected to the output
|
||||
if self.selected_nodes.iter().any(|&node_id| network.connected_to_output(node_id, true)) {
|
||||
if self.selected_nodes.iter().any(|&node_id| network.connected_to_output(node_id)) {
|
||||
if let Some(layer_path) = self.layer_path.clone() {
|
||||
responses.add(DocumentMessage::InputFrameRasterizeRegionBelowLayer { layer_path });
|
||||
}
|
||||
|
@ -549,7 +549,7 @@ impl MessageHandler<NodeGraphMessage, (&mut Document, &NodeGraphExecutor, u64)>
|
|||
}
|
||||
responses.add(NodeGraphMessage::SetNodeInput { node_id, input_index, input });
|
||||
|
||||
let should_rerender = network.connected_to_output(node_id, true);
|
||||
let should_rerender = network.connected_to_output(node_id);
|
||||
responses.add(NodeGraphMessage::SendGraph { should_rerender });
|
||||
}
|
||||
NodeGraphMessage::DoubleClickNode { node } => {
|
||||
|
@ -626,7 +626,7 @@ impl MessageHandler<NodeGraphMessage, (&mut Document, &NodeGraphExecutor, u64)>
|
|||
}
|
||||
responses.add(NodeGraphMessage::SetNodeInput { node_id, input_index, input });
|
||||
|
||||
let should_rerender = network.connected_to_output(node_id, true);
|
||||
let should_rerender = network.connected_to_output(node_id);
|
||||
responses.add(NodeGraphMessage::SendGraph { should_rerender });
|
||||
responses.add(PropertiesPanelMessage::ResendActiveProperties);
|
||||
}
|
||||
|
@ -743,7 +743,7 @@ impl MessageHandler<NodeGraphMessage, (&mut Document, &NodeGraphExecutor, u64)>
|
|||
let input = NodeInput::Value { tagged_value: value, exposed: false };
|
||||
responses.add(NodeGraphMessage::SetNodeInput { node_id, input_index, input });
|
||||
responses.add(PropertiesPanelMessage::ResendActiveProperties);
|
||||
if (node.name != "Imaginate" || input_index == 0) && network.connected_to_output(node_id, true) {
|
||||
if (node.name != "Imaginate" || input_index == 0) && network.connected_to_output(node_id) {
|
||||
if let Some(layer_path) = self.layer_path.clone() {
|
||||
responses.add(DocumentMessage::InputFrameRasterizeRegionBelowLayer { layer_path });
|
||||
} else {
|
||||
|
@ -780,7 +780,7 @@ impl MessageHandler<NodeGraphMessage, (&mut Document, &NodeGraphExecutor, u64)>
|
|||
node.inputs.extend(((node.inputs.len() - 1)..input_index).map(|_| NodeInput::Network(generic!(T))));
|
||||
}
|
||||
node.inputs[input_index] = NodeInput::Value { tagged_value: value, exposed: false };
|
||||
if network.connected_to_output(*node_id, true) {
|
||||
if network.connected_to_output(*node_id) {
|
||||
responses.add(DocumentMessage::InputFrameRasterizeRegionBelowLayer { layer_path });
|
||||
}
|
||||
}
|
||||
|
@ -854,7 +854,7 @@ impl MessageHandler<NodeGraphMessage, (&mut Document, &NodeGraphExecutor, u64)>
|
|||
Self::send_graph(network, executor, &self.layer_path, responses);
|
||||
|
||||
// Only generate node graph if one of the selected nodes is connected to the output
|
||||
if self.selected_nodes.iter().any(|&node_id| network.connected_to_output(node_id, true)) {
|
||||
if self.selected_nodes.iter().any(|&node_id| network.connected_to_output(node_id)) {
|
||||
if let Some(layer_path) = self.layer_path.clone() {
|
||||
responses.add(DocumentMessage::InputFrameRasterizeRegionBelowLayer { layer_path });
|
||||
}
|
||||
|
|
|
@ -1673,12 +1673,63 @@ fn static_nodes() -> Vec<DocumentNodeType> {
|
|||
pub static IMAGINATE_NODE: Lazy<DocumentNodeType> = Lazy::new(|| DocumentNodeType {
|
||||
name: "Imaginate",
|
||||
category: "Image Synthesis",
|
||||
identifier: NodeImplementation::proto("graphene_std::raster::ImaginateNode<_>"),
|
||||
identifier: NodeImplementation::DocumentNode(NodeNetwork {
|
||||
inputs: vec![0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
|
||||
outputs: vec![NodeOutput::new(1, 0)],
|
||||
nodes: [
|
||||
(
|
||||
0,
|
||||
DocumentNode {
|
||||
name: "Frame Monitor".into(),
|
||||
inputs: vec![NodeInput::Network(concrete!(ImageFrame<Color>))],
|
||||
implementation: DocumentNodeImplementation::proto("graphene_core::memo::MonitorNode<_>"),
|
||||
..Default::default()
|
||||
},
|
||||
),
|
||||
(
|
||||
1,
|
||||
DocumentNode {
|
||||
name: "Imaginate".into(),
|
||||
inputs: vec![
|
||||
NodeInput::node(0, 0),
|
||||
NodeInput::Network(concrete!(WasmEditorApi)),
|
||||
NodeInput::Network(concrete!(ImaginateController)),
|
||||
NodeInput::Network(concrete!(f64)),
|
||||
NodeInput::Network(concrete!(Option<DVec2>)),
|
||||
NodeInput::Network(concrete!(u32)),
|
||||
NodeInput::Network(concrete!(ImaginateSamplingMethod)),
|
||||
NodeInput::Network(concrete!(f64)),
|
||||
NodeInput::Network(concrete!(String)),
|
||||
NodeInput::Network(concrete!(String)),
|
||||
NodeInput::Network(concrete!(bool)),
|
||||
NodeInput::Network(concrete!(f64)),
|
||||
NodeInput::Network(concrete!(Option<Vec<u64>>)),
|
||||
NodeInput::Network(concrete!(bool)),
|
||||
NodeInput::Network(concrete!(f64)),
|
||||
NodeInput::Network(concrete!(ImaginateMaskStartingFill)),
|
||||
NodeInput::Network(concrete!(bool)),
|
||||
NodeInput::Network(concrete!(bool)),
|
||||
NodeInput::Network(concrete!(ImaginateCache)),
|
||||
],
|
||||
implementation: DocumentNodeImplementation::proto("graphene_std::raster::ImaginateNode<_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _>"),
|
||||
..Default::default()
|
||||
},
|
||||
),
|
||||
]
|
||||
.into(),
|
||||
..Default::default()
|
||||
}),
|
||||
inputs: vec![
|
||||
DocumentInputType::value("Input Image", TaggedValue::ImageFrame(ImageFrame::empty()), true),
|
||||
DocumentInputType {
|
||||
name: "Editor Api",
|
||||
data_type: FrontendGraphDataType::General,
|
||||
default: NodeInput::Network(concrete!(WasmEditorApi)),
|
||||
},
|
||||
DocumentInputType::value("Controller", TaggedValue::ImaginateController(Default::default()), false),
|
||||
DocumentInputType::value("Seed", TaggedValue::F64(0.), false), // Remember to keep index used in `ImaginateRandom` updated with this entry's index
|
||||
DocumentInputType::value("Resolution", TaggedValue::OptionalDVec2(None), false),
|
||||
DocumentInputType::value("Samples", TaggedValue::F64(30.), false),
|
||||
DocumentInputType::value("Samples", TaggedValue::U32(30), false),
|
||||
DocumentInputType::value("Sampling Method", TaggedValue::ImaginateSamplingMethod(ImaginateSamplingMethod::EulerA), false),
|
||||
DocumentInputType::value("Prompt Guidance", TaggedValue::F64(7.5), false),
|
||||
DocumentInputType::value("Prompt", TaggedValue::String(String::new()), false),
|
||||
|
@ -1691,10 +1742,7 @@ pub static IMAGINATE_NODE: Lazy<DocumentNodeType> = Lazy::new(|| DocumentNodeTyp
|
|||
DocumentInputType::value("Mask Starting Fill", TaggedValue::ImaginateMaskStartingFill(ImaginateMaskStartingFill::Fill), false),
|
||||
DocumentInputType::value("Improve Faces", TaggedValue::Bool(false), false),
|
||||
DocumentInputType::value("Tiling", TaggedValue::Bool(false), false),
|
||||
// Non-user status (is document input the right way to do this?)
|
||||
DocumentInputType::value("Cached Data", TaggedValue::RcImage(None), false),
|
||||
DocumentInputType::value("Percent Complete", TaggedValue::F64(0.), false),
|
||||
DocumentInputType::value("Status", TaggedValue::ImaginateStatus(ImaginateStatus::Idle), false),
|
||||
DocumentInputType::value("Cache", TaggedValue::ImaginateCache(Default::default()), false),
|
||||
],
|
||||
outputs: vec![DocumentOutputType::new("Image", FrontendGraphDataType::Raster)],
|
||||
properties: node_properties::imaginate_properties,
|
||||
|
|
|
@ -5,9 +5,11 @@ use super::FrontendGraphDataType;
|
|||
use crate::messages::layout::utility_types::widget_prelude::*;
|
||||
use crate::messages::prelude::*;
|
||||
|
||||
use document_legacy::{layers::layer_info::LayerDataTypeDiscriminant, Operation};
|
||||
use graph_craft::concrete;
|
||||
use graph_craft::document::value::TaggedValue;
|
||||
use graph_craft::document::{DocumentNode, NodeId, NodeInput};
|
||||
use graph_craft::imaginate_input::{ImaginateMaskStartingFill, ImaginateSamplingMethod, ImaginateServerStatus, ImaginateStatus};
|
||||
use graphene_core::raster::{BlendMode, Color, ImageFrame, LuminanceCalculation, RedGreenBlue, RelativeAbsolute, SelectiveColorChoice};
|
||||
use graphene_core::text::Font;
|
||||
use graphene_core::vector::style::{FillType, GradientType, LineCap, LineJoin};
|
||||
|
@ -980,12 +982,17 @@ pub fn node_section_font(document_node: &DocumentNode, node_id: NodeId, _context
|
|||
result
|
||||
}
|
||||
|
||||
pub fn imaginate_properties(_document_node: &DocumentNode, _node_id: NodeId, _context: &mut NodePropertiesContext) -> Vec<LayoutGroup> {
|
||||
/*
|
||||
pub fn imaginate_properties(document_node: &DocumentNode, node_id: NodeId, context: &mut NodePropertiesContext) -> Vec<LayoutGroup> {
|
||||
let imaginate_node = [context.nested_path, &[node_id]].concat();
|
||||
let layer_path = context.layer_path.to_vec();
|
||||
|
||||
let resolve_input = |name: &str| IMAGINATE_NODE.inputs.iter().position(|input| input.name == name).unwrap_or_else(|| panic!("Input {name} not found"));
|
||||
let resolve_input = |name: &str| {
|
||||
super::IMAGINATE_NODE
|
||||
.inputs
|
||||
.iter()
|
||||
.position(|input| input.name == name)
|
||||
.unwrap_or_else(|| panic!("Input {name} not found"))
|
||||
};
|
||||
let seed_index = resolve_input("Seed");
|
||||
let resolution_index = resolve_input("Resolution");
|
||||
let samples_index = resolve_input("Samples");
|
||||
|
@ -1001,22 +1008,12 @@ pub fn imaginate_properties(_document_node: &DocumentNode, _node_id: NodeId, _co
|
|||
let mask_fill_index = resolve_input("Mask Starting Fill");
|
||||
let faces_index = resolve_input("Improve Faces");
|
||||
let tiling_index = resolve_input("Tiling");
|
||||
let cached_index = resolve_input("Cached Data");
|
||||
|
||||
let cached_value = &document_node.inputs[cached_index];
|
||||
let complete_value = &document_node.inputs[resolve_input("Percent Complete")];
|
||||
let status_value = &document_node.inputs[resolve_input("Status")];
|
||||
let controller = &document_node.inputs[resolve_input("Controller")];
|
||||
|
||||
let server_status = {
|
||||
let status = match &context.persistent_data.imaginate_server_status {
|
||||
ImaginateServerStatus::Unknown => {
|
||||
context.responses.add(PortfolioMessage::ImaginateCheckServerStatus);
|
||||
"Checking..."
|
||||
}
|
||||
ImaginateServerStatus::Checking => "Checking...",
|
||||
ImaginateServerStatus::Unavailable => "Unavailable",
|
||||
ImaginateServerStatus::Connected => "Connected",
|
||||
};
|
||||
let server_status = context.persistent_data.imaginate.server_status();
|
||||
let status_text = server_status.to_text();
|
||||
let mut widgets = vec![
|
||||
WidgetHolder::text_widget("Server"),
|
||||
WidgetHolder::unrelated_separator(),
|
||||
|
@ -1025,14 +1022,14 @@ pub fn imaginate_properties(_document_node: &DocumentNode, _node_id: NodeId, _co
|
|||
.on_update(|_| DialogMessage::RequestPreferencesDialog.into())
|
||||
.widget_holder(),
|
||||
WidgetHolder::unrelated_separator(),
|
||||
WidgetHolder::bold_text(status),
|
||||
WidgetHolder::bold_text(status_text),
|
||||
WidgetHolder::related_separator(),
|
||||
IconButton::new("Reload", 24)
|
||||
.tooltip("Refresh connection status")
|
||||
.on_update(|_| PortfolioMessage::ImaginateCheckServerStatus.into())
|
||||
.widget_holder(),
|
||||
];
|
||||
if context.persistent_data.imaginate_server_status == ImaginateServerStatus::Unavailable {
|
||||
if let ImaginateServerStatus::Unavailable | ImaginateServerStatus::Failed(_) = server_status {
|
||||
widgets.extend([
|
||||
WidgetHolder::unrelated_separator(),
|
||||
TextButton::new("Server Help")
|
||||
|
@ -1049,15 +1046,11 @@ pub fn imaginate_properties(_document_node: &DocumentNode, _node_id: NodeId, _co
|
|||
LayoutGroup::Row { widgets }.with_tooltip("Connection status to the server that computes generated images")
|
||||
};
|
||||
|
||||
let &NodeInput::Value {tagged_value: TaggedValue::ImaginateStatus( imaginate_status),..} = status_value else {
|
||||
panic!("Invalid status input")
|
||||
};
|
||||
let NodeInput::Value {tagged_value: TaggedValue::RcImage( cached_data),..} = cached_value else {
|
||||
panic!("Invalid cached image input, received {:?}, index: {}", cached_value, cached_index)
|
||||
};
|
||||
let &NodeInput::Value {tagged_value: TaggedValue::F64( percent_complete),..} = complete_value else {
|
||||
panic!("Invalid percent complete input")
|
||||
let &NodeInput::Value {tagged_value: TaggedValue::ImaginateController(ref controller),..} = controller else {
|
||||
panic!("Invalid output status input")
|
||||
};
|
||||
let imaginate_status = controller.get_status();
|
||||
|
||||
let use_base_image = if let &NodeInput::Value {
|
||||
tagged_value: TaggedValue::Bool(use_base_image),
|
||||
..
|
||||
|
@ -1071,23 +1064,7 @@ pub fn imaginate_properties(_document_node: &DocumentNode, _node_id: NodeId, _co
|
|||
let transform_not_connected = false;
|
||||
|
||||
let progress = {
|
||||
// Since we don't serialize the status, we need to derive from other state whether the Idle state is actually supposed to be the Terminated state
|
||||
let mut interpreted_status = imaginate_status;
|
||||
if imaginate_status == ImaginateStatus::Idle && cached_data.is_some() && percent_complete > 0. && percent_complete < 100. {
|
||||
interpreted_status = ImaginateStatus::Terminated;
|
||||
}
|
||||
|
||||
let status = match interpreted_status {
|
||||
ImaginateStatus::Idle => match cached_data {
|
||||
Some(_) => "Done".into(),
|
||||
None => "Ready".into(),
|
||||
},
|
||||
ImaginateStatus::Beginning => "Beginning...".into(),
|
||||
ImaginateStatus::Uploading(percent) => format!("Uploading Input Image: {percent:.0}%"),
|
||||
ImaginateStatus::Generating => format!("Generating: {percent_complete:.0}%"),
|
||||
ImaginateStatus::Terminating => "Terminating...".into(),
|
||||
ImaginateStatus::Terminated => format!("{percent_complete:.0}% (Terminated)"),
|
||||
};
|
||||
let status = imaginate_status.to_text();
|
||||
let widgets = vec![
|
||||
WidgetHolder::text_widget("Progress"),
|
||||
WidgetHolder::unrelated_separator(),
|
||||
|
@ -1095,38 +1072,38 @@ pub fn imaginate_properties(_document_node: &DocumentNode, _node_id: NodeId, _co
|
|||
WidgetHolder::unrelated_separator(), // TODO: which is the width of the Assist area.
|
||||
WidgetHolder::unrelated_separator(), // TODO: Remove these when we have proper entry row formatting that includes room for Assists.
|
||||
WidgetHolder::unrelated_separator(),
|
||||
WidgetHolder::bold_text(status),
|
||||
WidgetHolder::bold_text(status.as_ref()),
|
||||
];
|
||||
LayoutGroup::Row { widgets }.with_tooltip("When generating, the percentage represents how many sampling steps have so far been processed out of the target number")
|
||||
LayoutGroup::Row { widgets }.with_tooltip(match imaginate_status {
|
||||
ImaginateStatus::Failed(_) => status.as_ref(),
|
||||
_ => "When generating, the percentage represents how many sampling steps have so far been processed out of the target number",
|
||||
})
|
||||
};
|
||||
|
||||
let image_controls = {
|
||||
let image_controls: _ = {
|
||||
let mut widgets = vec![WidgetHolder::text_widget("Image"), WidgetHolder::unrelated_separator()];
|
||||
let assist_separators = vec![
|
||||
let assist_separators = [
|
||||
WidgetHolder::unrelated_separator(), // TODO: These three separators add up to 24px,
|
||||
WidgetHolder::unrelated_separator(), // TODO: which is the width of the Assist area.
|
||||
WidgetHolder::unrelated_separator(), // TODO: Remove these when we have proper entry row formatting that includes room for Assists.
|
||||
WidgetHolder::unrelated_separator(),
|
||||
];
|
||||
|
||||
match imaginate_status {
|
||||
ImaginateStatus::Beginning | ImaginateStatus::Uploading(_) => {
|
||||
match &imaginate_status {
|
||||
ImaginateStatus::Beginning | ImaginateStatus::Uploading => {
|
||||
widgets.extend_from_slice(&assist_separators);
|
||||
widgets.push(TextButton::new("Beginning...").tooltip("Sending image generation request to the server").disabled(true).widget_holder());
|
||||
}
|
||||
ImaginateStatus::Generating => {
|
||||
ImaginateStatus::Generating(_) => {
|
||||
widgets.extend_from_slice(&assist_separators);
|
||||
widgets.push(
|
||||
TextButton::new("Terminate")
|
||||
.tooltip("Cancel the in-progress image generation and keep the latest progress")
|
||||
.on_update({
|
||||
let imaginate_node = imaginate_node.clone();
|
||||
let controller = controller.clone();
|
||||
move |_| {
|
||||
DocumentMessage::ImaginateTerminate {
|
||||
layer_path: layer_path.clone(),
|
||||
node_path: imaginate_node.clone(),
|
||||
}
|
||||
.into()
|
||||
controller.request_termination();
|
||||
Message::NoOp
|
||||
}
|
||||
})
|
||||
.widget_holder(),
|
||||
|
@ -1141,13 +1118,15 @@ pub fn imaginate_properties(_document_node: &DocumentNode, _node_id: NodeId, _co
|
|||
.widget_holder(),
|
||||
);
|
||||
}
|
||||
ImaginateStatus::Idle | ImaginateStatus::Terminated => widgets.extend_from_slice(&[
|
||||
ImaginateStatus::Ready | ImaginateStatus::ReadyDone | ImaginateStatus::Terminated | ImaginateStatus::Failed(_) => widgets.extend_from_slice(&[
|
||||
IconButton::new("Random", 24)
|
||||
.tooltip("Generate with a new random seed")
|
||||
.on_update({
|
||||
let imaginate_node = imaginate_node.clone();
|
||||
let layer_path = context.layer_path.to_vec();
|
||||
let controller = controller.clone();
|
||||
move |_| {
|
||||
controller.trigger_regenerate();
|
||||
DocumentMessage::ImaginateRandom {
|
||||
layer_path: layer_path.clone(),
|
||||
imaginate_node: imaginate_node.clone(),
|
||||
|
@ -1163,7 +1142,9 @@ pub fn imaginate_properties(_document_node: &DocumentNode, _node_id: NodeId, _co
|
|||
.on_update({
|
||||
let imaginate_node = imaginate_node.clone();
|
||||
let layer_path = context.layer_path.to_vec();
|
||||
let controller = controller.clone();
|
||||
move |_| {
|
||||
controller.trigger_regenerate();
|
||||
DocumentMessage::ImaginateGenerate {
|
||||
layer_path: layer_path.clone(),
|
||||
imaginate_node: imaginate_node.clone(),
|
||||
|
@ -1175,16 +1156,13 @@ pub fn imaginate_properties(_document_node: &DocumentNode, _node_id: NodeId, _co
|
|||
WidgetHolder::related_separator(),
|
||||
TextButton::new("Clear")
|
||||
.tooltip("Remove generated image from the layer frame")
|
||||
.disabled(cached_data.is_none())
|
||||
.disabled(!matches!(imaginate_status, ImaginateStatus::ReadyDone))
|
||||
.on_update({
|
||||
let layer_path = context.layer_path.to_vec();
|
||||
let controller = controller.clone();
|
||||
move |_| {
|
||||
DocumentMessage::ImaginateClear {
|
||||
node_id,
|
||||
layer_path: layer_path.clone(),
|
||||
cached_index,
|
||||
}
|
||||
.into()
|
||||
controller.set_status(ImaginateStatus::Ready);
|
||||
DocumentMessage::ImaginateClear { layer_path: layer_path.clone() }.into()
|
||||
}
|
||||
})
|
||||
.widget_holder(),
|
||||
|
@ -1221,9 +1199,11 @@ pub fn imaginate_properties(_document_node: &DocumentNode, _node_id: NodeId, _co
|
|||
.widget_holder(),
|
||||
WidgetHolder::unrelated_separator(),
|
||||
NumberInput::new(Some(seed))
|
||||
.min(0.)
|
||||
.int()
|
||||
.min(-((1u64 << f64::MANTISSA_DIGITS) as f64))
|
||||
.max((1u64 << f64::MANTISSA_DIGITS) as f64)
|
||||
.on_update(update_value(move |input: &NumberInput| TaggedValue::F64(input.value.unwrap()), node_id, seed_index))
|
||||
.mode(NumberInputMode::Increment)
|
||||
.widget_holder(),
|
||||
])
|
||||
}
|
||||
|
@ -1231,23 +1211,19 @@ pub fn imaginate_properties(_document_node: &DocumentNode, _node_id: NodeId, _co
|
|||
LayoutGroup::Row { widgets }.with_tooltip("Seed determines the random outcome, enabling limitless unique variations")
|
||||
};
|
||||
|
||||
// Create the input to the graph using an empty image
|
||||
let editor_api = std::borrow::Cow::Owned(EditorApi {
|
||||
image_frame: None,
|
||||
font_cache: Some(&context.persistent_data.font_cache),
|
||||
});
|
||||
// Compute the transform input to the image frame
|
||||
let image_frame: ImageFrame<Color> = context.executor.compute_input(context.network, &imaginate_node, 0, editor_api).unwrap_or_default();
|
||||
let transform = image_frame.transform;
|
||||
let transform = context
|
||||
.executor
|
||||
.introspect_node_in_network(context.network, &imaginate_node, |network| network.inputs.first().copied(), |frame: &ImageFrame<Color>| frame.transform)
|
||||
.unwrap_or_default();
|
||||
|
||||
let resolution = {
|
||||
use document_legacy::document::pick_safe_imaginate_resolution;
|
||||
use graphene_std::imaginate::pick_safe_imaginate_resolution;
|
||||
|
||||
let mut widgets = start_widgets(document_node, node_id, resolution_index, "Resolution", FrontendGraphDataType::Vector, false);
|
||||
|
||||
let round = |x: DVec2| {
|
||||
let (x, y) = pick_safe_imaginate_resolution(x.into());
|
||||
Some(DVec2::new(x as f64, y as f64))
|
||||
DVec2::new(x as f64, y as f64)
|
||||
};
|
||||
|
||||
if let &NodeInput::Value {
|
||||
|
@ -1256,14 +1232,7 @@ pub fn imaginate_properties(_document_node: &DocumentNode, _node_id: NodeId, _co
|
|||
} = &document_node.inputs[resolution_index]
|
||||
{
|
||||
let dimensions_is_auto = vec2.is_none();
|
||||
let vec2 = vec2.unwrap_or_else(|| {
|
||||
let w = transform.transform_vector2(DVec2::new(1., 0.)).length();
|
||||
let h = transform.transform_vector2(DVec2::new(0., 1.)).length();
|
||||
|
||||
let (x, y) = pick_safe_imaginate_resolution((w, h));
|
||||
|
||||
DVec2::new(x as f64, y as f64)
|
||||
});
|
||||
let vec2 = vec2.unwrap_or_else(|| round([transform.matrix2.x_axis, transform.matrix2.y_axis].map(DVec2::length).into()));
|
||||
|
||||
let layer_path = context.layer_path.to_vec();
|
||||
widgets.extend_from_slice(&[
|
||||
|
@ -1308,7 +1277,7 @@ pub fn imaginate_properties(_document_node: &DocumentNode, _node_id: NodeId, _co
|
|||
.unit(" px")
|
||||
.disabled(dimensions_is_auto && !transform_not_connected)
|
||||
.on_update(update_value(
|
||||
move |number_input: &NumberInput| TaggedValue::OptionalDVec2(round(DVec2::new(number_input.value.unwrap(), vec2.y))),
|
||||
move |number_input: &NumberInput| TaggedValue::OptionalDVec2(Some(round(DVec2::new(number_input.value.unwrap(), vec2.y)))),
|
||||
node_id,
|
||||
resolution_index,
|
||||
))
|
||||
|
@ -1321,7 +1290,7 @@ pub fn imaginate_properties(_document_node: &DocumentNode, _node_id: NodeId, _co
|
|||
.unit(" px")
|
||||
.disabled(dimensions_is_auto && !transform_not_connected)
|
||||
.on_update(update_value(
|
||||
move |number_input: &NumberInput| TaggedValue::OptionalDVec2(round(DVec2::new(vec2.x, number_input.value.unwrap()))),
|
||||
move |number_input: &NumberInput| TaggedValue::OptionalDVec2(Some(round(DVec2::new(vec2.x, number_input.value.unwrap())))),
|
||||
node_id,
|
||||
resolution_index,
|
||||
))
|
||||
|
@ -1538,8 +1507,6 @@ pub fn imaginate_properties(_document_node: &DocumentNode, _node_id: NodeId, _co
|
|||
layout.extend_from_slice(&[improve_faces, tiling]);
|
||||
|
||||
layout
|
||||
*/
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn unknown_node_properties(document_node: &DocumentNode) -> Vec<LayoutGroup> {
|
||||
|
|
|
@ -1,10 +1,8 @@
|
|||
use super::utility_types::ImaginateServerStatus;
|
||||
use crate::messages::portfolio::document::utility_types::clipboards::Clipboard;
|
||||
use crate::messages::prelude::*;
|
||||
|
||||
use document_legacy::LayerId;
|
||||
use graph_craft::document::NodeId;
|
||||
use graph_craft::imaginate_input::ImaginateStatus;
|
||||
use graphene_core::text::Font;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
@ -57,24 +55,9 @@ pub enum PortfolioMessage {
|
|||
is_default: bool,
|
||||
},
|
||||
ImaginateCheckServerStatus,
|
||||
ImaginateSetGeneratingStatus {
|
||||
document_id: u64,
|
||||
layer_path: Vec<LayerId>,
|
||||
node_path: Vec<NodeId>,
|
||||
percent: Option<f64>,
|
||||
status: ImaginateStatus,
|
||||
},
|
||||
ImaginateSetImageData {
|
||||
document_id: u64,
|
||||
layer_path: Vec<LayerId>,
|
||||
node_path: Vec<NodeId>,
|
||||
image_data: Vec<u8>,
|
||||
width: u32,
|
||||
height: u32,
|
||||
},
|
||||
ImaginateSetServerStatus {
|
||||
status: ImaginateServerStatus,
|
||||
},
|
||||
ImaginatePollServerStatus,
|
||||
ImaginatePreferences,
|
||||
ImaginateServerHostname,
|
||||
Import,
|
||||
LoadDocumentResources {
|
||||
document_id: u64,
|
||||
|
|
|
@ -7,9 +7,7 @@ use crate::messages::dialog::simple_dialogs;
|
|||
use crate::messages::frontend::utility_types::FrontendDocumentDetails;
|
||||
use crate::messages::layout::utility_types::layout_widget::PropertyHolder;
|
||||
use crate::messages::layout::utility_types::misc::LayoutTarget;
|
||||
use crate::messages::portfolio::document::node_graph::IMAGINATE_NODE;
|
||||
use crate::messages::portfolio::document::utility_types::clipboards::{Clipboard, CopyBufferEntry, INTERNAL_CLIPBOARD_COUNT};
|
||||
use crate::messages::portfolio::utility_types::ImaginateServerStatus;
|
||||
use crate::messages::prelude::*;
|
||||
use crate::messages::tool::utility_types::{HintData, HintGroup};
|
||||
use crate::node_graph_executor::NodeGraphExecutor;
|
||||
|
@ -19,7 +17,6 @@ use document_legacy::layers::style::RenderData;
|
|||
use document_legacy::Operation as DocumentOperation;
|
||||
use graph_craft::document::value::TaggedValue;
|
||||
use graph_craft::document::{NodeId, NodeInput};
|
||||
use graphene_core::raster::Image;
|
||||
use graphene_core::text::Font;
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
|
@ -218,70 +215,35 @@ impl MessageHandler<PortfolioMessage, (&InputPreprocessorMessageHandler, &Prefer
|
|||
self.executor.update_font_cache(self.persistent_data.font_cache.clone());
|
||||
}
|
||||
PortfolioMessage::ImaginateCheckServerStatus => {
|
||||
self.persistent_data.imaginate_server_status = ImaginateServerStatus::Checking;
|
||||
responses.add(FrontendMessage::TriggerImaginateCheckServerStatus {
|
||||
hostname: preferences.imaginate_server_hostname.clone(),
|
||||
});
|
||||
responses.add(PropertiesPanelMessage::ResendActiveProperties);
|
||||
}
|
||||
PortfolioMessage::ImaginateSetGeneratingStatus {
|
||||
document_id,
|
||||
layer_path,
|
||||
node_path,
|
||||
percent,
|
||||
status,
|
||||
} => {
|
||||
let get = |name: &str| IMAGINATE_NODE.inputs.iter().position(|input| input.name == name).unwrap_or_else(|| panic!("Input {name} not found"));
|
||||
if let Some(percentage) = percent {
|
||||
responses.add(PortfolioMessage::DocumentPassMessage {
|
||||
document_id,
|
||||
message: NodeGraphMessage::SetQualifiedInputValue {
|
||||
layer_path: layer_path.clone(),
|
||||
node_path: node_path.clone(),
|
||||
input_index: get("Percent Complete"),
|
||||
value: TaggedValue::F64(percentage),
|
||||
let server_status = self.persistent_data.imaginate.server_status().clone();
|
||||
self.persistent_data.imaginate.poll_server_check();
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
if let Some(fut) = self.persistent_data.imaginate.initiate_server_check() {
|
||||
future_executor::spawn(async move {
|
||||
let () = fut.await;
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
#[wasm_bindgen(module = "/../frontend/src/wasm-communication/editor.ts")]
|
||||
extern "C" {
|
||||
#[wasm_bindgen(js_name = injectImaginatePollServerStatus)]
|
||||
fn inject();
|
||||
}
|
||||
.into(),
|
||||
});
|
||||
inject();
|
||||
})
|
||||
}
|
||||
if &server_status != self.persistent_data.imaginate.server_status() {
|
||||
responses.add(PropertiesPanelMessage::ResendActiveProperties);
|
||||
}
|
||||
|
||||
responses.add(PortfolioMessage::DocumentPassMessage {
|
||||
document_id,
|
||||
message: NodeGraphMessage::SetQualifiedInputValue {
|
||||
layer_path,
|
||||
node_path,
|
||||
input_index: get("Status"),
|
||||
value: TaggedValue::ImaginateStatus(status),
|
||||
}
|
||||
.into(),
|
||||
});
|
||||
}
|
||||
PortfolioMessage::ImaginateSetImageData {
|
||||
document_id,
|
||||
layer_path,
|
||||
node_path,
|
||||
image_data,
|
||||
width,
|
||||
height,
|
||||
} => {
|
||||
let get = |name: &str| IMAGINATE_NODE.inputs.iter().position(|input| input.name == name).unwrap_or_else(|| panic!("Input {name} not found"));
|
||||
|
||||
let image = Image::from_image_data(&image_data, width, height);
|
||||
responses.add(PortfolioMessage::DocumentPassMessage {
|
||||
document_id,
|
||||
message: NodeGraphMessage::SetQualifiedInputValue {
|
||||
layer_path,
|
||||
node_path,
|
||||
input_index: get("Cached Data"),
|
||||
value: TaggedValue::RcImage(Some(std::sync::Arc::new(image))),
|
||||
}
|
||||
.into(),
|
||||
});
|
||||
}
|
||||
PortfolioMessage::ImaginateSetServerStatus { status } => {
|
||||
self.persistent_data.imaginate_server_status = status;
|
||||
PortfolioMessage::ImaginatePollServerStatus => {
|
||||
self.persistent_data.imaginate.poll_server_check();
|
||||
responses.add(PropertiesPanelMessage::ResendActiveProperties);
|
||||
}
|
||||
PortfolioMessage::ImaginatePreferences => self.executor.update_imaginate_preferences(preferences.get_imaginate_preferences()),
|
||||
PortfolioMessage::ImaginateServerHostname => {
|
||||
info!("setting imaginate persistent data");
|
||||
self.persistent_data.imaginate.set_host_name(&preferences.imaginate_server_hostname);
|
||||
}
|
||||
PortfolioMessage::Import => {
|
||||
// This portfolio message wraps the frontend message so it can be listed as an action, which isn't possible for frontend messages
|
||||
if self.active_document().is_some() {
|
||||
|
@ -461,7 +423,6 @@ impl MessageHandler<PortfolioMessage, (&InputPreprocessorMessageHandler, &Prefer
|
|||
(document_id, &mut self.documents),
|
||||
layer_path,
|
||||
(input_image_data, size),
|
||||
imaginate_node_path,
|
||||
(preferences, &self.persistent_data),
|
||||
responses,
|
||||
);
|
||||
|
|
|
@ -1,29 +1,11 @@
|
|||
use graphene_std::text::FontCache;
|
||||
use graphene_std::{imaginate::ImaginatePersistentData, text::FontCache};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Debug, Default)]
|
||||
pub struct PersistentData {
|
||||
pub font_cache: FontCache,
|
||||
pub imaginate_server_status: ImaginateServerStatus,
|
||||
}
|
||||
|
||||
impl Default for PersistentData {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
font_cache: Default::default(),
|
||||
imaginate_server_status: ImaginateServerStatus::Unknown,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq, Clone, Copy, Default, Debug, Serialize, Deserialize, specta::Type)]
|
||||
pub enum ImaginateServerStatus {
|
||||
#[default]
|
||||
Unknown,
|
||||
Checking,
|
||||
Unavailable,
|
||||
Connected,
|
||||
pub imaginate: ImaginatePersistentData,
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq, Clone, Copy, Default, Debug, Serialize, Deserialize)]
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
use crate::messages::input_mapper::key_mapping::MappingVariant;
|
||||
use crate::messages::prelude::*;
|
||||
use graph_craft::imaginate_input::ImaginatePreferences;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
|
@ -10,10 +11,19 @@ pub struct PreferencesMessageHandler {
|
|||
pub zoom_with_scroll: bool,
|
||||
}
|
||||
|
||||
impl PreferencesMessageHandler {
|
||||
pub fn get_imaginate_preferences(&self) -> ImaginatePreferences {
|
||||
ImaginatePreferences {
|
||||
host_name: self.imaginate_server_hostname.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for PreferencesMessageHandler {
|
||||
fn default() -> Self {
|
||||
let ImaginatePreferences { host_name } = Default::default();
|
||||
Self {
|
||||
imaginate_server_hostname: "http://localhost:7860/".into(),
|
||||
imaginate_server_hostname: host_name,
|
||||
imaginate_refresh_frequency: 1.,
|
||||
zoom_with_scroll: matches!(MappingVariant::default(), MappingVariant::ZoomWithScroll),
|
||||
}
|
||||
|
@ -28,9 +38,9 @@ impl MessageHandler<PreferencesMessage, ()> for PreferencesMessageHandler {
|
|||
if let Ok(deserialized_preferences) = serde_json::from_str::<PreferencesMessageHandler>(&preferences) {
|
||||
*self = deserialized_preferences;
|
||||
|
||||
if self.imaginate_server_hostname != Self::default().imaginate_server_hostname {
|
||||
responses.add(PortfolioMessage::ImaginateCheckServerStatus);
|
||||
}
|
||||
responses.add(PortfolioMessage::ImaginateServerHostname);
|
||||
responses.add(PortfolioMessage::ImaginateCheckServerStatus);
|
||||
responses.add(PortfolioMessage::ImaginatePreferences);
|
||||
}
|
||||
}
|
||||
PreferencesMessage::ResetToDefaults => {
|
||||
|
@ -43,6 +53,7 @@ impl MessageHandler<PreferencesMessage, ()> for PreferencesMessageHandler {
|
|||
PreferencesMessage::ImaginateRefreshFrequency { seconds } => {
|
||||
self.imaginate_refresh_frequency = seconds;
|
||||
responses.add(PortfolioMessage::ImaginateCheckServerStatus);
|
||||
responses.add(PortfolioMessage::ImaginatePreferences);
|
||||
}
|
||||
PreferencesMessage::ImaginateServerHostname { hostname } => {
|
||||
let initial = hostname.clone();
|
||||
|
@ -55,7 +66,9 @@ impl MessageHandler<PreferencesMessage, ()> for PreferencesMessageHandler {
|
|||
}
|
||||
|
||||
self.imaginate_server_hostname = hostname;
|
||||
responses.add(PortfolioMessage::ImaginateServerHostname);
|
||||
responses.add(PortfolioMessage::ImaginateCheckServerStatus);
|
||||
responses.add(PortfolioMessage::ImaginatePreferences);
|
||||
}
|
||||
PreferencesMessage::ModifyLayout { zoom_with_scroll } => {
|
||||
self.zoom_with_scroll = zoom_with_scroll;
|
||||
|
|
|
@ -10,16 +10,16 @@ use document_legacy::{LayerId, Operation};
|
|||
use graph_craft::document::value::TaggedValue;
|
||||
use graph_craft::document::{generate_uuid, DocumentNodeImplementation, NodeId, NodeNetwork};
|
||||
use graph_craft::graphene_compiler::Compiler;
|
||||
use graph_craft::imaginate_input::ImaginatePreferences;
|
||||
use graph_craft::{concrete, Type, TypeDescriptor};
|
||||
use graphene_core::application_io::ApplicationIo;
|
||||
use graphene_core::application_io::{ApplicationIo, NodeGraphUpdateMessage, NodeGraphUpdateSender};
|
||||
use graphene_core::raster::{Image, ImageFrame};
|
||||
use graphene_core::renderer::{SvgSegment, SvgSegmentList};
|
||||
use graphene_core::text::FontCache;
|
||||
use graphene_core::vector::style::ViewMode;
|
||||
|
||||
use graphene_core::{Color, SurfaceFrame, SurfaceId};
|
||||
use graphene_std::wasm_application_io::WasmApplicationIo;
|
||||
use graphene_std::wasm_application_io::WasmEditorApi;
|
||||
use graphene_std::wasm_application_io::{WasmApplicationIo, WasmEditorApi};
|
||||
use interpreted_executor::dynamic_executor::DynamicExecutor;
|
||||
|
||||
use glam::{DAffine2, DVec2};
|
||||
|
@ -33,8 +33,9 @@ pub struct NodeRuntime {
|
|||
pub(crate) executor: DynamicExecutor,
|
||||
font_cache: FontCache,
|
||||
receiver: Receiver<NodeRuntimeMessage>,
|
||||
sender: Sender<GenerationResponse>,
|
||||
sender: InternalNodeGraphUpdateSender,
|
||||
wasm_io: Option<WasmApplicationIo>,
|
||||
imaginate_preferences: ImaginatePreferences,
|
||||
pub(crate) thumbnails: HashMap<LayerId, HashMap<NodeId, SvgSegmentList>>,
|
||||
canvas_cache: HashMap<Vec<LayerId>, SurfaceId>,
|
||||
}
|
||||
|
@ -42,6 +43,7 @@ pub struct NodeRuntime {
|
|||
enum NodeRuntimeMessage {
|
||||
GenerationRequest(GenerationRequest),
|
||||
FontCacheUpdate(FontCache),
|
||||
ImaginatePreferencesUpdate(ImaginatePreferences),
|
||||
}
|
||||
|
||||
pub(crate) struct GenerationRequest {
|
||||
|
@ -50,6 +52,7 @@ pub(crate) struct GenerationRequest {
|
|||
path: Vec<LayerId>,
|
||||
image_frame: Option<ImageFrame<Color>>,
|
||||
}
|
||||
|
||||
pub(crate) struct GenerationResponse {
|
||||
generation_id: u64,
|
||||
result: Result<TaggedValue, String>,
|
||||
|
@ -57,18 +60,38 @@ pub(crate) struct GenerationResponse {
|
|||
new_thumbnails: HashMap<LayerId, HashMap<NodeId, SvgSegmentList>>,
|
||||
}
|
||||
|
||||
enum NodeGraphUpdate {
|
||||
GenerationResponse(GenerationResponse),
|
||||
NodeGraphUpdateMessage(NodeGraphUpdateMessage),
|
||||
}
|
||||
|
||||
struct InternalNodeGraphUpdateSender(Sender<NodeGraphUpdate>);
|
||||
|
||||
impl InternalNodeGraphUpdateSender {
|
||||
fn send_generation_response(&self, response: GenerationResponse) {
|
||||
self.0.send(NodeGraphUpdate::GenerationResponse(response)).expect("Failed to send response")
|
||||
}
|
||||
}
|
||||
|
||||
impl NodeGraphUpdateSender for InternalNodeGraphUpdateSender {
|
||||
fn send(&self, message: NodeGraphUpdateMessage) {
|
||||
self.0.send(NodeGraphUpdate::NodeGraphUpdateMessage(message)).expect("Failed to send response")
|
||||
}
|
||||
}
|
||||
|
||||
thread_local! {
|
||||
pub(crate) static NODE_RUNTIME: Rc<RefCell<Option<NodeRuntime>>> = Rc::new(RefCell::new(None));
|
||||
}
|
||||
|
||||
impl NodeRuntime {
|
||||
fn new(receiver: Receiver<NodeRuntimeMessage>, sender: Sender<GenerationResponse>) -> Self {
|
||||
fn new(receiver: Receiver<NodeRuntimeMessage>, sender: Sender<NodeGraphUpdate>) -> Self {
|
||||
let executor = DynamicExecutor::default();
|
||||
Self {
|
||||
executor,
|
||||
receiver,
|
||||
sender,
|
||||
sender: InternalNodeGraphUpdateSender(sender),
|
||||
font_cache: FontCache::default(),
|
||||
imaginate_preferences: Default::default(),
|
||||
thumbnails: Default::default(),
|
||||
wasm_io: None,
|
||||
canvas_cache: Default::default(),
|
||||
|
@ -80,13 +103,14 @@ impl NodeRuntime {
|
|||
// This should be avoided in the future.
|
||||
requests.reverse();
|
||||
requests.dedup_by_key(|x| match x {
|
||||
NodeRuntimeMessage::FontCacheUpdate(_) => None,
|
||||
NodeRuntimeMessage::GenerationRequest(x) => Some(x.path.clone()),
|
||||
_ => None,
|
||||
});
|
||||
requests.reverse();
|
||||
for request in requests {
|
||||
match request {
|
||||
NodeRuntimeMessage::FontCacheUpdate(font_cache) => self.font_cache = font_cache,
|
||||
NodeRuntimeMessage::ImaginatePreferencesUpdate(preferences) => self.imaginate_preferences = preferences,
|
||||
NodeRuntimeMessage::GenerationRequest(GenerationRequest {
|
||||
generation_id,
|
||||
graph,
|
||||
|
@ -105,7 +129,7 @@ impl NodeRuntime {
|
|||
updates: responses,
|
||||
new_thumbnails: self.thumbnails.clone(),
|
||||
};
|
||||
self.sender.send(response).expect("Failed to send response");
|
||||
self.sender.send_generation_response(response);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -134,6 +158,8 @@ impl NodeRuntime {
|
|||
font_cache: &self.font_cache,
|
||||
image_frame,
|
||||
application_io: &self.wasm_io.as_ref().unwrap(),
|
||||
node_graph_message_sender: &self.sender,
|
||||
imaginate_preferences: &self.imaginate_preferences,
|
||||
};
|
||||
|
||||
// We assume only one output
|
||||
|
@ -240,7 +266,7 @@ pub async fn run_node_graph() {
|
|||
#[derive(Debug)]
|
||||
pub struct NodeGraphExecutor {
|
||||
sender: Sender<NodeRuntimeMessage>,
|
||||
receiver: Receiver<GenerationResponse>,
|
||||
receiver: Receiver<NodeGraphUpdate>,
|
||||
// TODO: This is a memory leak since layers are never removed
|
||||
pub(crate) last_output_type: HashMap<Vec<LayerId>, Option<Type>>,
|
||||
pub(crate) thumbnails: HashMap<LayerId, HashMap<NodeId, SvgSegmentList>>,
|
||||
|
@ -294,10 +320,31 @@ impl NodeGraphExecutor {
|
|||
self.sender.send(NodeRuntimeMessage::FontCacheUpdate(font_cache)).expect("Failed to send font cache update");
|
||||
}
|
||||
|
||||
pub fn update_imaginate_preferences(&self, imaginate_preferences: ImaginatePreferences) {
|
||||
self.sender
|
||||
.send(NodeRuntimeMessage::ImaginatePreferencesUpdate(imaginate_preferences))
|
||||
.expect("Failed to send imaginate preferences");
|
||||
}
|
||||
|
||||
pub fn previous_output_type(&self, path: &[LayerId]) -> Option<Type> {
|
||||
self.last_output_type.get(path).cloned().flatten()
|
||||
}
|
||||
|
||||
pub fn introspect_node_in_network<T: std::any::Any + core::fmt::Debug, U, F1: FnOnce(&NodeNetwork) -> Option<NodeId>, F2: FnOnce(&T) -> U>(
|
||||
&mut self,
|
||||
network: &NodeNetwork,
|
||||
node_path: &[NodeId],
|
||||
find_node: F1,
|
||||
extract_data: F2,
|
||||
) -> Option<U> {
|
||||
let wrapping_document_node = network.nodes.get(node_path.last()?)?;
|
||||
let DocumentNodeImplementation::Network(wrapped_network) = &wrapping_document_node.implementation else { return None; };
|
||||
let introspection_node = find_node(&wrapped_network)?;
|
||||
let introspection = self.introspect_node(&[node_path, &[introspection_node]].concat())?;
|
||||
let downcasted: &T = <dyn std::any::Any>::downcast_ref(introspection.as_ref())?;
|
||||
Some(extract_data(downcasted))
|
||||
}
|
||||
|
||||
/// Encodes an image into a format using the image crate
|
||||
fn encode_img(image: Image<Color>, resize: Option<DVec2>, format: image::ImageOutputFormat) -> Result<(Vec<u8>, (u32, u32)), String> {
|
||||
use image::{ImageBuffer, Rgba};
|
||||
|
@ -334,13 +381,12 @@ impl NodeGraphExecutor {
|
|||
})
|
||||
}
|
||||
|
||||
/// Evaluates a node graph, computing either the Imaginate node or the entire graph
|
||||
/// Evaluates a node graph, computing the entire graph
|
||||
pub fn submit_node_graph_evaluation(
|
||||
&mut self,
|
||||
(document_id, documents): (u64, &mut HashMap<u64, DocumentMessageHandler>),
|
||||
layer_path: Vec<LayerId>,
|
||||
(input_image_data, (width, height)): (Vec<u8>, (u32, u32)),
|
||||
_imaginate_node: Option<Vec<NodeId>>,
|
||||
_persistent_data: (&PreferencesMessageHandler, &PersistentData),
|
||||
_responses: &mut VecDeque<Message>,
|
||||
) -> Result<(), String> {
|
||||
|
@ -365,11 +411,6 @@ impl NodeGraphExecutor {
|
|||
let transform = DAffine2::IDENTITY;
|
||||
let image_frame = ImageFrame { image, transform };
|
||||
|
||||
// Special execution path for generating Imaginate (as generation requires IO from outside node graph)
|
||||
/*if let Some(imaginate_node) = imaginate_node {
|
||||
responses.add(self.generate_imaginate(network, imaginate_node, (document, document_id), layer_path, editor_api, persistent_data)?);
|
||||
return Ok(());
|
||||
}*/
|
||||
// Execute the node graph
|
||||
let generation_id = self.queue_execution(network, Some(image_frame), layer_path.clone());
|
||||
|
||||
|
@ -381,26 +422,32 @@ impl NodeGraphExecutor {
|
|||
pub fn poll_node_graph_evaluation(&mut self, responses: &mut VecDeque<Message>) -> Result<(), String> {
|
||||
let results = self.receiver.try_iter().collect::<Vec<_>>();
|
||||
for response in results {
|
||||
let GenerationResponse {
|
||||
generation_id,
|
||||
result,
|
||||
updates,
|
||||
new_thumbnails,
|
||||
} = response;
|
||||
self.thumbnails = new_thumbnails;
|
||||
let node_graph_output = result.map_err(|e| format!("Node graph evaluation failed: {:?}", e))?;
|
||||
let execution_context = self.futures.remove(&generation_id).ok_or_else(|| "Invalid generation ID".to_string())?;
|
||||
responses.extend(updates);
|
||||
self.process_node_graph_output(node_graph_output, execution_context.layer_path.clone(), responses, execution_context.document_id)?;
|
||||
responses.add(DocumentMessage::LayerChanged {
|
||||
affected_layer_path: execution_context.layer_path,
|
||||
});
|
||||
responses.add(DocumentMessage::RenderDocument);
|
||||
responses.add(ArtboardMessage::RenderArtboards);
|
||||
responses.add(DocumentMessage::DocumentStructureChanged);
|
||||
responses.add(BroadcastEvent::DocumentIsDirty);
|
||||
responses.add(DocumentMessage::DirtyRenderDocument);
|
||||
responses.add(DocumentMessage::Overlays(OverlaysMessage::Rerender));
|
||||
match response {
|
||||
NodeGraphUpdate::GenerationResponse(GenerationResponse {
|
||||
generation_id,
|
||||
result,
|
||||
updates,
|
||||
new_thumbnails,
|
||||
}) => {
|
||||
self.thumbnails = new_thumbnails;
|
||||
let node_graph_output = result.map_err(|e| format!("Node graph evaluation failed: {:?}", e))?;
|
||||
let execution_context = self.futures.remove(&generation_id).ok_or_else(|| "Invalid generation ID".to_string())?;
|
||||
responses.extend(updates);
|
||||
self.process_node_graph_output(node_graph_output, execution_context.layer_path.clone(), responses, execution_context.document_id)?;
|
||||
responses.add(DocumentMessage::LayerChanged {
|
||||
affected_layer_path: execution_context.layer_path,
|
||||
});
|
||||
responses.add(DocumentMessage::RenderDocument);
|
||||
responses.add(ArtboardMessage::RenderArtboards);
|
||||
responses.add(DocumentMessage::DocumentStructureChanged);
|
||||
responses.add(BroadcastEvent::DocumentIsDirty);
|
||||
responses.add(DocumentMessage::DirtyRenderDocument);
|
||||
responses.add(DocumentMessage::Overlays(OverlaysMessage::Rerender));
|
||||
}
|
||||
NodeGraphUpdate::NodeGraphUpdateMessage(NodeGraphUpdateMessage::ImaginateStatusUpdate) => {
|
||||
responses.add(DocumentMessage::PropertiesPanel(PropertiesPanelMessage::ResendActiveProperties))
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
@ -3,7 +3,6 @@
|
|||
import {writable} from "svelte/store";
|
||||
|
||||
import { downloadFileText, downloadFileBlob, upload, downloadFileURL } from "@graphite/utility-functions/files";
|
||||
import { imaginateGenerate, imaginateCheckConnection, imaginateTerminate, updateBackendImage } from "@graphite/utility-functions/imaginate";
|
||||
import { extractPixelData, imageToPNG, rasterizeSVG, rasterizeSVGCanvas } from "@graphite/utility-functions/rasterization";
|
||||
import { type Editor } from "@graphite/wasm-communication/editor";
|
||||
import {
|
||||
|
@ -13,8 +12,6 @@ import {
|
|||
TriggerDownloadRaster,
|
||||
TriggerDownloadTextFile,
|
||||
TriggerImaginateCheckServerStatus,
|
||||
TriggerImaginateGenerate,
|
||||
TriggerImaginateTerminate,
|
||||
TriggerImport,
|
||||
TriggerOpenDocument,
|
||||
TriggerRasterizeRegionBelowLayer,
|
||||
|
@ -85,37 +82,6 @@ export function createPortfolioState(editor: Editor) {
|
|||
// Have the browser download the file to the user's disk
|
||||
downloadFileBlob(name, blob);
|
||||
});
|
||||
editor.subscriptions.subscribeJsMessage(TriggerImaginateCheckServerStatus, async (triggerImaginateCheckServerStatus) => {
|
||||
const { hostname } = triggerImaginateCheckServerStatus;
|
||||
|
||||
imaginateCheckConnection(hostname, editor);
|
||||
});
|
||||
editor.subscriptions.subscribeJsMessage(TriggerImaginateGenerate, async (triggerImaginateGenerate) => {
|
||||
const { documentId, layerPath, nodePath, hostname, refreshFrequency, baseImage, maskImage, maskPaintMode, maskBlurPx, maskFillContent, parameters } = triggerImaginateGenerate;
|
||||
|
||||
// Handle img2img mode
|
||||
let image: Blob | undefined;
|
||||
if (parameters.denoisingStrength !== undefined && baseImage !== undefined) {
|
||||
const buffer = new Uint8Array(baseImage.imageData.values()).buffer;
|
||||
|
||||
image = new Blob([buffer], { type: baseImage.mime });
|
||||
updateBackendImage(editor, image, documentId, layerPath, nodePath);
|
||||
}
|
||||
|
||||
// Handle layer mask
|
||||
let mask: Blob | undefined;
|
||||
if (maskImage !== undefined) {
|
||||
// Rasterize the SVG to an image file
|
||||
mask = await rasterizeSVG(maskImage.svg, maskImage.size[0], maskImage.size[1], "image/png");
|
||||
}
|
||||
|
||||
imaginateGenerate(parameters, image, mask, maskPaintMode, maskBlurPx, maskFillContent, hostname, refreshFrequency, documentId, layerPath, nodePath, editor);
|
||||
});
|
||||
editor.subscriptions.subscribeJsMessage(TriggerImaginateTerminate, async (triggerImaginateTerminate) => {
|
||||
const { documentId, layerPath, nodePath, hostname } = triggerImaginateTerminate;
|
||||
|
||||
imaginateTerminate(hostname, documentId, layerPath, nodePath, editor);
|
||||
});
|
||||
editor.subscriptions.subscribeJsMessage(UpdateImageData, (updateImageData) => {
|
||||
updateImageData.imageData.forEach(async (element) => {
|
||||
const buffer = new Uint8Array(element.imageData.values()).buffer;
|
||||
|
|
|
@ -1,367 +0,0 @@
|
|||
/* eslint-disable camelcase */
|
||||
|
||||
// import { escapeJSON } from "@graphite/utility-functions/escape";
|
||||
import { blobToBase64 } from "@graphite/utility-functions/files";
|
||||
import { type RequestResult, requestWithUploadDownloadProgress } from "@graphite/utility-functions/network";
|
||||
import { type Editor } from "@graphite/wasm-communication/editor";
|
||||
import type { XY } from "@graphite/wasm-communication/messages";
|
||||
import { type ImaginateGenerationParameters } from "@graphite/wasm-communication/messages";
|
||||
|
||||
const MAX_POLLING_RETRIES = 4;
|
||||
const SERVER_STATUS_CHECK_TIMEOUT = 5000;
|
||||
const PROGRESS_EVERY_N_STEPS = 5;
|
||||
|
||||
let timer: NodeJS.Timeout | undefined;
|
||||
let terminated = false;
|
||||
|
||||
let generatingAbortRequest: XMLHttpRequest | undefined;
|
||||
let pollingAbortController = new AbortController();
|
||||
let statusAbortController = new AbortController();
|
||||
|
||||
// PUBLICLY CALLABLE FUNCTIONS
|
||||
|
||||
export async function imaginateGenerate(
|
||||
parameters: ImaginateGenerationParameters,
|
||||
image: Blob | undefined,
|
||||
mask: Blob | undefined,
|
||||
maskPaintMode: string,
|
||||
maskBlurPx: number,
|
||||
maskFillContent: string,
|
||||
hostname: string,
|
||||
refreshFrequency: number,
|
||||
documentId: bigint,
|
||||
layerPath: BigUint64Array,
|
||||
nodePath: BigUint64Array,
|
||||
editor: Editor
|
||||
): Promise<void> {
|
||||
// Ignore a request to generate a new image while another is already being generated
|
||||
if (generatingAbortRequest !== undefined) return;
|
||||
|
||||
terminated = false;
|
||||
|
||||
// Immediately set the progress to 0% so the backend knows to update its layout
|
||||
editor.instance.setImaginateGeneratingStatus(documentId, layerPath, nodePath, 0, "Beginning");
|
||||
|
||||
// Initiate a request to the computation server
|
||||
const discloseUploadingProgress = (progress: number): void => {
|
||||
editor.instance.setImaginateGeneratingStatus(documentId, layerPath, nodePath, progress * 100, "Uploading");
|
||||
};
|
||||
const { uploaded, result, xhr } = await generate(discloseUploadingProgress, hostname, image, mask, maskPaintMode, maskBlurPx, maskFillContent, parameters);
|
||||
generatingAbortRequest = xhr;
|
||||
|
||||
try {
|
||||
// Wait until the request is fully uploaded, which could be slow if the img2img source is large and the user is on a slow connection
|
||||
await uploaded;
|
||||
editor.instance.setImaginateGeneratingStatus(documentId, layerPath, nodePath, 0, "Generating");
|
||||
|
||||
// Begin polling for updates to the in-progress image generation at the specified interval
|
||||
// Don't poll if the chosen interval is 0, or if the chosen sampling method does not support polling
|
||||
if (refreshFrequency > 0) {
|
||||
const interval = Math.max(refreshFrequency * 1000, 500);
|
||||
scheduleNextPollingUpdate(interval, Date.now(), 0, editor, hostname, documentId, layerPath, nodePath, parameters.resolution);
|
||||
}
|
||||
|
||||
// Wait for the final image to be returned by the initial request containing either the full image or the last frame if it was terminated by the user
|
||||
const { body, status } = await result;
|
||||
if (status < 200 || status > 299) {
|
||||
throw new Error(`Request to server failed to return a 200-level status code (${status})`);
|
||||
}
|
||||
|
||||
// Extract the final image from the response and convert it to a data blob
|
||||
const base64Data = JSON.parse(body)?.images?.[0] as string | undefined;
|
||||
const base64 = typeof base64Data === "string" && base64Data.length > 0 ? `data:image/png;base64,${base64Data}` : undefined;
|
||||
if (!base64) throw new Error("Could not read final image result from server response");
|
||||
const blob = await (await fetch(base64)).blob();
|
||||
|
||||
// Send the backend an updated status
|
||||
const percent = terminated ? undefined : 100;
|
||||
const newStatus = terminated ? "Terminated" : "Idle";
|
||||
editor.instance.setImaginateGeneratingStatus(documentId, layerPath, nodePath, percent, newStatus);
|
||||
|
||||
// Send the backend a blob URL for the final image
|
||||
updateBackendImage(editor, blob, documentId, layerPath, nodePath);
|
||||
} catch {
|
||||
editor.instance.setImaginateGeneratingStatus(documentId, layerPath, nodePath, undefined, "Terminated");
|
||||
|
||||
await imaginateCheckConnection(hostname, editor);
|
||||
}
|
||||
|
||||
abortAndResetGenerating();
|
||||
abortAndResetPolling();
|
||||
}
|
||||
|
||||
export async function imaginateTerminate(hostname: string, documentId: bigint, layerPath: BigUint64Array, nodePath: BigUint64Array, editor: Editor): Promise<void> {
|
||||
terminated = true;
|
||||
abortAndResetPolling();
|
||||
|
||||
try {
|
||||
await terminate(hostname);
|
||||
|
||||
editor.instance.setImaginateGeneratingStatus(documentId, layerPath, nodePath, undefined, "Terminating");
|
||||
} catch {
|
||||
abortAndResetGenerating();
|
||||
abortAndResetPolling();
|
||||
|
||||
editor.instance.setImaginateGeneratingStatus(documentId, layerPath, nodePath, undefined, "Terminated");
|
||||
|
||||
await imaginateCheckConnection(hostname, editor);
|
||||
}
|
||||
}
|
||||
|
||||
export async function imaginateCheckConnection(hostname: string, editor: Editor): Promise<void> {
|
||||
const serverReached = await checkConnection(hostname);
|
||||
editor.instance.setImaginateServerStatus(serverReached);
|
||||
}
|
||||
|
||||
// Converts the blob image into a list of pixels using an invisible canvas.
|
||||
export async function updateBackendImage(editor: Editor, blob: Blob, documentId: bigint, layerPath: BigUint64Array, nodePath: BigUint64Array): Promise<void> {
|
||||
const image = await createImageBitmap(blob);
|
||||
const canvas = document.createElement("canvas");
|
||||
canvas.width = image.width;
|
||||
canvas.height = image.height;
|
||||
const ctx = canvas.getContext("2d");
|
||||
if (!ctx) throw new Error("Could not create canvas context");
|
||||
ctx.drawImage(image, 0, 0);
|
||||
|
||||
// Send the backend the blob data to be stored persistently in the layer
|
||||
const imageData = ctx.getImageData(0, 0, image.width, image.height);
|
||||
const u8Array = new Uint8Array(imageData.data);
|
||||
|
||||
editor.instance.setImaginateImageData(documentId, layerPath, nodePath, u8Array, imageData.width, imageData.height);
|
||||
}
|
||||
|
||||
// ABORTING AND RESETTING HELPERS
|
||||
|
||||
function abortAndResetGenerating(): void {
|
||||
generatingAbortRequest?.abort();
|
||||
generatingAbortRequest = undefined;
|
||||
}
|
||||
|
||||
function abortAndResetPolling(): void {
|
||||
pollingAbortController.abort();
|
||||
pollingAbortController = new AbortController();
|
||||
clearTimeout(timer);
|
||||
}
|
||||
|
||||
// POLLING IMPLEMENTATION DETAILS
|
||||
|
||||
function scheduleNextPollingUpdate(
|
||||
interval: number,
|
||||
timeoutBegan: number,
|
||||
pollingRetries: number,
|
||||
editor: Editor,
|
||||
hostname: string,
|
||||
documentId: bigint,
|
||||
layerPath: BigUint64Array,
|
||||
nodePath: BigUint64Array,
|
||||
resolution: XY
|
||||
): void {
|
||||
// Pick a future time that keeps to the user-requested interval if possible, but on slower connections will go as fast as possible without overlapping itself
|
||||
const nextPollTimeGoal = timeoutBegan + interval;
|
||||
const timeFromNow = Math.max(0, nextPollTimeGoal - Date.now());
|
||||
|
||||
timer = setTimeout(async () => {
|
||||
const nextTimeoutBegan = Date.now();
|
||||
|
||||
try {
|
||||
const [blob, percentComplete] = await pollImage(hostname);
|
||||
|
||||
// After waiting for the polling result back from the server, if during that intervening time the user has terminated the generation, exit so we don't overwrite that terminated status
|
||||
if (terminated) return;
|
||||
|
||||
if (blob) updateBackendImage(editor, blob, documentId, layerPath, nodePath);
|
||||
editor.instance.setImaginateGeneratingStatus(documentId, layerPath, nodePath, percentComplete, "Generating");
|
||||
|
||||
scheduleNextPollingUpdate(interval, nextTimeoutBegan, 0, editor, hostname, documentId, layerPath, nodePath, resolution);
|
||||
} catch {
|
||||
if (generatingAbortRequest === undefined) return;
|
||||
|
||||
if (pollingRetries + 1 > MAX_POLLING_RETRIES) {
|
||||
abortAndResetGenerating();
|
||||
abortAndResetPolling();
|
||||
|
||||
await imaginateCheckConnection(hostname, editor);
|
||||
} else {
|
||||
scheduleNextPollingUpdate(interval, nextTimeoutBegan, pollingRetries + 1, editor, hostname, documentId, layerPath, nodePath, resolution);
|
||||
}
|
||||
}
|
||||
}, timeFromNow);
|
||||
}
|
||||
|
||||
// API COMMUNICATION FUNCTIONS
|
||||
|
||||
async function pollImage(hostname: string): Promise<[Blob | undefined, number]> {
|
||||
// Fetch the percent progress and in-progress image from the API
|
||||
const result = await fetch(`${hostname}sdapi/v1/progress`, { signal: pollingAbortController.signal, method: "GET" });
|
||||
const { current_image, progress } = await result.json();
|
||||
|
||||
// Convert to a usable format
|
||||
const progressPercent = progress * 100;
|
||||
const base64 = typeof current_image === "string" && current_image.length > 0 ? `data:image/png;base64,${current_image}` : undefined;
|
||||
|
||||
// Deal with a missing image
|
||||
if (!base64) {
|
||||
// The image is not ready yet (because it's only had a few samples since generation began), but we do have a progress percentage
|
||||
if (!Number.isNaN(progressPercent) && progressPercent >= 0 && progressPercent <= 100) {
|
||||
return [undefined, progressPercent];
|
||||
}
|
||||
|
||||
// Something else is wrong and the image wasn't provided as expected
|
||||
return Promise.reject();
|
||||
}
|
||||
|
||||
// The image was provided so we turn it into a data blob
|
||||
const blob = await (await fetch(base64)).blob();
|
||||
return [blob, progressPercent];
|
||||
}
|
||||
|
||||
async function generate(
|
||||
discloseUploadingProgress: (progress: number) => void,
|
||||
hostname: string,
|
||||
image: Blob | undefined,
|
||||
mask: Blob | undefined,
|
||||
maskPaintMode: string,
|
||||
maskBlurPx: number,
|
||||
maskFillContent: string,
|
||||
parameters: ImaginateGenerationParameters
|
||||
): Promise<{
|
||||
uploaded: Promise<void>;
|
||||
result: Promise<RequestResult>;
|
||||
xhr?: XMLHttpRequest;
|
||||
}> {
|
||||
let body;
|
||||
let endpoint;
|
||||
if (image === undefined || parameters.denoisingStrength === undefined) {
|
||||
endpoint = `${hostname}sdapi/v1/txt2img`;
|
||||
|
||||
body = {
|
||||
// enable_hr: false,
|
||||
// denoising_strength: 0,
|
||||
// firstphase_width: 0,
|
||||
// firstphase_height: 0,
|
||||
prompt: parameters.prompt,
|
||||
// styles: [],
|
||||
seed: Number(parameters.seed),
|
||||
// subseed: -1,
|
||||
// subseed_strength: 0,
|
||||
// seed_resize_from_h: -1,
|
||||
// seed_resize_from_w: -1,
|
||||
// batch_size: 1,
|
||||
// n_iter: 1,
|
||||
steps: parameters.samples,
|
||||
cfg_scale: parameters.cfgScale,
|
||||
width: parameters.resolution.x,
|
||||
height: parameters.resolution.y,
|
||||
restore_faces: parameters.restoreFaces,
|
||||
tiling: parameters.tiling,
|
||||
negative_prompt: parameters.negativePrompt,
|
||||
// eta: 0,
|
||||
// s_churn: 0,
|
||||
// s_tmax: 0,
|
||||
// s_tmin: 0,
|
||||
// s_noise: 1,
|
||||
override_settings: {
|
||||
show_progress_every_n_steps: PROGRESS_EVERY_N_STEPS,
|
||||
},
|
||||
sampler_index: parameters.samplingMethod,
|
||||
};
|
||||
} else {
|
||||
const sourceImageBase64 = await blobToBase64(image);
|
||||
const maskImageBase64 = mask ? await blobToBase64(mask) : "";
|
||||
|
||||
const maskFillContentIndexes = ["Fill", "Original", "LatentNoise", "LatentNothing"];
|
||||
const maskFillContentIndexFound = maskFillContentIndexes.indexOf(maskFillContent);
|
||||
const maskFillContentIndex = maskFillContentIndexFound === -1 ? undefined : maskFillContentIndexFound;
|
||||
|
||||
const maskInvert = maskPaintMode === "Inpaint" ? 1 : 0;
|
||||
|
||||
endpoint = `${hostname}sdapi/v1/img2img`;
|
||||
|
||||
body = {
|
||||
init_images: [sourceImageBase64],
|
||||
// resize_mode: 0,
|
||||
denoising_strength: parameters.denoisingStrength,
|
||||
mask: mask && maskImageBase64,
|
||||
mask_blur: mask && maskBlurPx,
|
||||
inpainting_fill: mask && maskFillContentIndex,
|
||||
inpaint_full_res: mask && false,
|
||||
// inpaint_full_res_padding: 0,
|
||||
inpainting_mask_invert: mask && maskInvert,
|
||||
prompt: parameters.prompt,
|
||||
// styles: [],
|
||||
seed: Number(parameters.seed),
|
||||
// subseed: -1,
|
||||
// subseed_strength: 0,
|
||||
// seed_resize_from_h: -1,
|
||||
// seed_resize_from_w: -1,
|
||||
// batch_size: 1,
|
||||
// n_iter: 1,
|
||||
steps: parameters.samples,
|
||||
cfg_scale: parameters.cfgScale,
|
||||
width: parameters.resolution.x,
|
||||
height: parameters.resolution.y,
|
||||
restore_faces: parameters.restoreFaces,
|
||||
tiling: parameters.tiling,
|
||||
negative_prompt: parameters.negativePrompt,
|
||||
// eta: 0,
|
||||
// s_churn: 0,
|
||||
// s_tmax: 0,
|
||||
// s_tmin: 0,
|
||||
// s_noise: 1,
|
||||
override_settings: {
|
||||
show_progress_every_n_steps: PROGRESS_EVERY_N_STEPS,
|
||||
img2img_fix_steps: true,
|
||||
},
|
||||
sampler_index: parameters.samplingMethod,
|
||||
// include_init_images: false,
|
||||
};
|
||||
}
|
||||
|
||||
// Prepare a promise that will resolve after the outbound request upload is complete
|
||||
let uploadedResolve: () => void;
|
||||
let uploadedReject: () => void;
|
||||
const uploaded = new Promise<void>((resolve, reject): void => {
|
||||
uploadedResolve = resolve;
|
||||
uploadedReject = reject;
|
||||
});
|
||||
|
||||
// Fire off the request and, once the outbound request upload is complete, resolve the promise we defined above
|
||||
const uploadProgress = (progress: number): void => {
|
||||
if (progress < 1) {
|
||||
discloseUploadingProgress(progress);
|
||||
} else {
|
||||
uploadedResolve();
|
||||
}
|
||||
};
|
||||
const [result, xhr] = requestWithUploadDownloadProgress(endpoint, "POST", JSON.stringify(body), uploadProgress, abortAndResetPolling);
|
||||
result.catch(() => uploadedReject());
|
||||
|
||||
// Return the promise that resolves when the request upload is complete, the promise that resolves when the response download is complete, and the XHR so it can be aborted
|
||||
return { uploaded, result, xhr };
|
||||
}
|
||||
|
||||
async function terminate(hostname: string): Promise<void> {
|
||||
await fetch(`${hostname}sdapi/v1/interrupt`, { method: "POST" });
|
||||
}
|
||||
|
||||
async function checkConnection(hostname: string): Promise<boolean> {
|
||||
statusAbortController.abort();
|
||||
statusAbortController = new AbortController();
|
||||
|
||||
const timeout = setTimeout(() => statusAbortController.abort(), SERVER_STATUS_CHECK_TIMEOUT);
|
||||
|
||||
try {
|
||||
// Intentionally misuse this API endpoint by using it just to check for a code 200 response, regardless of what the result is
|
||||
const { status } = await fetch(`${hostname}sdapi/v1/progress?skip_current_image=true`, { signal: statusAbortController.signal, method: "GET" });
|
||||
|
||||
// This code means the server has indeed responded and the endpoint exists (otherwise it would be 404)
|
||||
if (status === 200) {
|
||||
clearTimeout(timeout);
|
||||
return true;
|
||||
}
|
||||
} catch {
|
||||
// Do nothing here
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
|
@ -95,3 +95,7 @@ export function createEditor() {
|
|||
subscriptions,
|
||||
};
|
||||
}
|
||||
|
||||
export function injectImaginatePollServerStatus() {
|
||||
window["editorInstance"]?.injectImaginatePollServerStatus()
|
||||
}
|
||||
|
|
|
@ -516,7 +516,7 @@ export class TriggerCopyToClipboardBlobUrl extends JsMessage {
|
|||
|
||||
export class TriggerDownloadBlobUrl extends JsMessage {
|
||||
readonly layerName!: string;
|
||||
|
||||
|
||||
readonly blobUrl!: string;
|
||||
}
|
||||
|
||||
|
@ -537,85 +537,6 @@ export class TriggerDownloadTextFile extends JsMessage {
|
|||
readonly name!: string;
|
||||
}
|
||||
|
||||
export class TriggerImaginateCheckServerStatus extends JsMessage {
|
||||
readonly hostname!: string;
|
||||
}
|
||||
|
||||
export class TriggerImaginateGenerate extends JsMessage {
|
||||
@Type(() => ImaginateGenerationParameters)
|
||||
readonly parameters!: ImaginateGenerationParameters;
|
||||
|
||||
@Type(() => ImaginateBaseImage)
|
||||
readonly baseImage!: ImaginateBaseImage | undefined;
|
||||
|
||||
@Type(() => ImaginateMaskImage)
|
||||
readonly maskImage: ImaginateMaskImage | undefined;
|
||||
|
||||
readonly maskPaintMode!: string;
|
||||
|
||||
readonly maskBlurPx!: number;
|
||||
|
||||
readonly maskFillContent!: string;
|
||||
|
||||
readonly hostname!: string;
|
||||
|
||||
readonly refreshFrequency!: number;
|
||||
|
||||
readonly documentId!: bigint;
|
||||
|
||||
readonly layerPath!: BigUint64Array;
|
||||
|
||||
readonly nodePath!: BigUint64Array;
|
||||
}
|
||||
|
||||
export class ImaginateMaskImage {
|
||||
readonly svg!: string;
|
||||
|
||||
readonly size!: [number, number];
|
||||
}
|
||||
|
||||
export class ImaginateBaseImage {
|
||||
readonly mime!: string;
|
||||
|
||||
readonly imageData!: Uint8Array;
|
||||
|
||||
@TupleToVec2
|
||||
readonly size!: [number, number];
|
||||
}
|
||||
|
||||
export class ImaginateGenerationParameters {
|
||||
readonly seed!: number;
|
||||
|
||||
readonly samples!: number;
|
||||
|
||||
readonly samplingMethod!: string;
|
||||
|
||||
readonly denoisingStrength!: number | undefined;
|
||||
|
||||
readonly cfgScale!: number;
|
||||
|
||||
readonly prompt!: string;
|
||||
|
||||
readonly negativePrompt!: string;
|
||||
|
||||
@BigIntTupleToVec2
|
||||
readonly resolution!: XY;
|
||||
|
||||
readonly restoreFaces!: boolean;
|
||||
|
||||
readonly tiling!: boolean;
|
||||
}
|
||||
|
||||
export class TriggerImaginateTerminate extends JsMessage {
|
||||
readonly documentId!: bigint;
|
||||
|
||||
readonly layerPath!: BigUint64Array;
|
||||
|
||||
readonly nodePath!: BigUint64Array;
|
||||
|
||||
readonly hostname!: string;
|
||||
}
|
||||
|
||||
export class TriggerRasterizeRegionBelowLayer extends JsMessage {
|
||||
readonly documentId!: bigint;
|
||||
|
||||
|
@ -778,7 +699,7 @@ export class ImaginateImageData {
|
|||
readonly mime!: string;
|
||||
|
||||
readonly imageData!: Uint8Array;
|
||||
|
||||
|
||||
readonly transform!: Float64Array ;
|
||||
}
|
||||
|
||||
|
@ -1404,9 +1325,6 @@ export const messageMakers: Record<string, MessageMaker> = {
|
|||
TriggerDownloadRaster,
|
||||
TriggerDownloadTextFile,
|
||||
TriggerFontLoad,
|
||||
TriggerImaginateCheckServerStatus,
|
||||
TriggerImaginateGenerate,
|
||||
TriggerImaginateTerminate,
|
||||
TriggerImport,
|
||||
TriggerIndexedDbRemoveDocument,
|
||||
TriggerIndexedDbWriteDocument,
|
||||
|
|
|
@ -30,9 +30,9 @@ graphene-core = { path = "../../node-graph/gcore", features = [
|
|||
] }
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
wasm-bindgen = { version = "=0.2.86" }
|
||||
serde-wasm-bindgen = "0.4.1"
|
||||
js-sys = "0.3.55"
|
||||
wasm-bindgen-futures = "0.4.33"
|
||||
serde-wasm-bindgen = "0.5.0"
|
||||
js-sys = "0.3.63"
|
||||
wasm-bindgen-futures = "0.4.36"
|
||||
ron = { version = "0.8", optional = true }
|
||||
bezier-rs = { path = "../../libraries/bezier-rs" }
|
||||
|
||||
|
|
|
@ -11,7 +11,7 @@ use editor::application::Editor;
|
|||
use editor::consts::{FILE_SAVE_SUFFIX, GRAPHITE_DOCUMENT_VERSION};
|
||||
use editor::messages::input_mapper::utility_types::input_keyboard::ModifierKeys;
|
||||
use editor::messages::input_mapper::utility_types::input_mouse::{EditorMouseState, ScrollDelta, ViewportBounds};
|
||||
use editor::messages::portfolio::utility_types::{ImaginateServerStatus, Platform};
|
||||
use editor::messages::portfolio::utility_types::Platform;
|
||||
use editor::messages::prelude::*;
|
||||
use graph_craft::document::NodeId;
|
||||
use graphene_core::raster::color::Color;
|
||||
|
@ -586,63 +586,6 @@ impl JsEditorHandle {
|
|||
}
|
||||
}
|
||||
|
||||
/// Sends the blob URL generated by JS to the Imaginate layer in the respective document
|
||||
#[wasm_bindgen(js_name = setImaginateImageData)]
|
||||
pub fn set_imaginate_image_data(&self, document_id: u64, layer_path: Vec<LayerId>, node_path: Vec<NodeId>, image_data: Vec<u8>, width: u32, height: u32) {
|
||||
let message = PortfolioMessage::ImaginateSetImageData {
|
||||
document_id,
|
||||
node_path,
|
||||
layer_path,
|
||||
image_data,
|
||||
width,
|
||||
height,
|
||||
};
|
||||
self.dispatch(message);
|
||||
}
|
||||
|
||||
/// Notifies the Imaginate layer of a new percentage of completion and whether or not it's currently generating
|
||||
#[wasm_bindgen(js_name = setImaginateGeneratingStatus)]
|
||||
pub fn set_imaginate_generating_status(&self, document_id: u64, layer_path: Vec<LayerId>, node_path: Vec<NodeId>, percent: Option<f64>, status: String) {
|
||||
use graph_craft::imaginate_input::ImaginateStatus;
|
||||
|
||||
let status = match status.as_str() {
|
||||
"Idle" => ImaginateStatus::Idle,
|
||||
"Beginning" => ImaginateStatus::Beginning,
|
||||
"Uploading" => ImaginateStatus::Uploading(percent.expect("Percent needs to be supplied to set ImaginateStatus::Uploading")),
|
||||
"Generating" => ImaginateStatus::Generating,
|
||||
"Terminating" => ImaginateStatus::Terminating,
|
||||
"Terminated" => ImaginateStatus::Terminated,
|
||||
_ => panic!("Invalid string from JS for ImaginateStatus, received: {}", status),
|
||||
};
|
||||
|
||||
let percent = if matches!(status, ImaginateStatus::Uploading(_)) { None } else { percent };
|
||||
|
||||
let message = PortfolioMessage::ImaginateSetGeneratingStatus {
|
||||
document_id,
|
||||
layer_path,
|
||||
node_path,
|
||||
percent,
|
||||
status,
|
||||
};
|
||||
self.dispatch(message);
|
||||
}
|
||||
|
||||
/// Notifies the editor that the Imaginate server is available or unavailable
|
||||
#[wasm_bindgen(js_name = setImaginateServerStatus)]
|
||||
pub fn set_imaginate_server_status(&self, available: bool) {
|
||||
let message: Message = match available {
|
||||
true => PortfolioMessage::ImaginateSetServerStatus {
|
||||
status: ImaginateServerStatus::Connected,
|
||||
}
|
||||
.into(),
|
||||
false => PortfolioMessage::ImaginateSetServerStatus {
|
||||
status: ImaginateServerStatus::Unavailable,
|
||||
}
|
||||
.into(),
|
||||
};
|
||||
self.dispatch(message);
|
||||
}
|
||||
|
||||
/// Sends the blob URL generated by JS to the Imaginate layer in the respective document
|
||||
#[wasm_bindgen(js_name = renderGraphUsingRasterizedRegionBelowLayer)]
|
||||
pub fn render_graph_using_rasterized_region_below_layer(
|
||||
|
@ -793,6 +736,11 @@ impl JsEditorHandle {
|
|||
});
|
||||
frontend_messages.unwrap().unwrap_or_default()
|
||||
}
|
||||
|
||||
#[wasm_bindgen(js_name = injectImaginatePollServerStatus)]
|
||||
pub fn inject_imaginate_poll_server_status(&self) {
|
||||
self.dispatch(PortfolioMessage::ImaginatePollServerStatus);
|
||||
}
|
||||
}
|
||||
|
||||
// Needed to make JsEditorHandle functions pub to Rust.
|
||||
|
|
|
@ -12,4 +12,9 @@ futures = "0.3.25"
|
|||
log = "0.4"
|
||||
|
||||
[target.wasm32-unknown-unknown.dependencies]
|
||||
wasm-rs-async-executor = {version = "0.9.0", features = ["cooperative-browser", "debug", "requestIdleCallback"] }
|
||||
wasm-rs-async-executor = { version = "0.9.0", features = [
|
||||
"cooperative-browser",
|
||||
"debug",
|
||||
"requestIdleCallback",
|
||||
] }
|
||||
wasm-bindgen-futures = "0.4.36"
|
||||
|
|
|
@ -22,3 +22,8 @@ pub fn block_on<F: Future + 'static>(future: F) -> F::Output {
|
|||
#[cfg(not(target_arch = "wasm32"))]
|
||||
futures::executor::block_on(future)
|
||||
}
|
||||
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
pub fn spawn<F: Future<Output = ()> + 'static>(future: F) {
|
||||
wasm_bindgen_futures::spawn_local(future);
|
||||
}
|
||||
|
|
|
@ -112,10 +112,25 @@ impl<T: ApplicationIo> ApplicationIo for &T {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum NodeGraphUpdateMessage {
|
||||
ImaginateStatusUpdate,
|
||||
}
|
||||
|
||||
pub trait NodeGraphUpdateSender {
|
||||
fn send(&self, message: NodeGraphUpdateMessage);
|
||||
}
|
||||
|
||||
pub trait GetImaginatePreferences {
|
||||
fn get_host_name(&self) -> &str;
|
||||
}
|
||||
|
||||
pub struct EditorApi<'a, Io> {
|
||||
pub image_frame: Option<ImageFrame<Color>>,
|
||||
pub font_cache: &'a FontCache,
|
||||
pub application_io: &'a Io,
|
||||
pub node_graph_message_sender: &'a dyn NodeGraphUpdateSender,
|
||||
pub imaginate_preferences: &'a dyn GetImaginatePreferences,
|
||||
}
|
||||
|
||||
impl<'a, Io> Clone for EditorApi<'a, Io> {
|
||||
|
@ -124,6 +139,8 @@ impl<'a, Io> Clone for EditorApi<'a, Io> {
|
|||
image_frame: self.image_frame.clone(),
|
||||
font_cache: self.font_cache,
|
||||
application_io: self.application_io,
|
||||
node_graph_message_sender: self.node_graph_message_sender,
|
||||
imaginate_preferences: self.imaginate_preferences,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -441,7 +441,7 @@ impl NodeNetwork {
|
|||
}
|
||||
|
||||
/// Check if the specified node id is connected to the output
|
||||
pub fn connected_to_output(&self, target_node_id: NodeId, ignore_imaginate: bool) -> bool {
|
||||
pub fn connected_to_output(&self, target_node_id: NodeId) -> bool {
|
||||
// If the node is the output then return true
|
||||
if self.outputs.iter().any(|&NodeOutput { node_id, .. }| node_id == target_node_id) {
|
||||
return true;
|
||||
|
@ -454,11 +454,6 @@ impl NodeNetwork {
|
|||
already_visited.extend(self.outputs.iter().map(|output| output.node_id));
|
||||
|
||||
while let Some(node) = stack.pop() {
|
||||
// Skip the imaginate node inputs
|
||||
if ignore_imaginate && node.name == "Imaginate" {
|
||||
continue;
|
||||
}
|
||||
|
||||
for input in &node.inputs {
|
||||
if let &NodeInput::Node { node_id: ref_id, .. } = input {
|
||||
// Skip if already viewed
|
||||
|
@ -680,7 +675,7 @@ impl NodeNetwork {
|
|||
|
||||
let mut dummy_input = NodeInput::ShortCircut(concrete!(()));
|
||||
std::mem::swap(&mut dummy_input, input);
|
||||
if let NodeInput::Value { tagged_value, exposed } = dummy_input {
|
||||
if let NodeInput::Value { mut tagged_value, exposed } = dummy_input {
|
||||
let value_node_id = gen_id();
|
||||
let merged_node_id = map_ids(id, value_node_id);
|
||||
let path = if let Some(mut new_path) = node.path.clone() {
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use super::DocumentNode;
|
||||
use crate::graphene_compiler::Any;
|
||||
pub use crate::imaginate_input::{ImaginateMaskStartingFill, ImaginateSamplingMethod, ImaginateStatus};
|
||||
pub use crate::imaginate_input::{ImaginateCache, ImaginateController, ImaginateMaskStartingFill, ImaginateSamplingMethod};
|
||||
use crate::proto::{Any as DAny, FutureAny};
|
||||
|
||||
use graphene_core::raster::brush_cache::BrushCache;
|
||||
|
@ -27,7 +27,7 @@ pub enum TaggedValue {
|
|||
OptionalDVec2(Option<DVec2>),
|
||||
DAffine2(DAffine2),
|
||||
Image(graphene_core::raster::Image<Color>),
|
||||
RcImage(Option<Arc<graphene_core::raster::Image<Color>>>),
|
||||
ImaginateCache(ImaginateCache),
|
||||
ImageFrame(graphene_core::raster::ImageFrame<Color>),
|
||||
Color(graphene_core::raster::color::Color),
|
||||
Subpaths(Vec<bezier_rs::Subpath<graphene_core::uuid::ManipulatorGroupId>>),
|
||||
|
@ -36,7 +36,7 @@ pub enum TaggedValue {
|
|||
LuminanceCalculation(LuminanceCalculation),
|
||||
ImaginateSamplingMethod(ImaginateSamplingMethod),
|
||||
ImaginateMaskStartingFill(ImaginateMaskStartingFill),
|
||||
ImaginateStatus(ImaginateStatus),
|
||||
ImaginateController(ImaginateController),
|
||||
LayerPath(Option<Vec<u64>>),
|
||||
VectorData(graphene_core::vector::VectorData),
|
||||
Fill(graphene_core::vector::style::Fill),
|
||||
|
@ -83,7 +83,7 @@ impl Hash for TaggedValue {
|
|||
}
|
||||
Self::DAffine2(m) => m.to_cols_array().iter().for_each(|x| x.to_bits().hash(state)),
|
||||
Self::Image(i) => i.hash(state),
|
||||
Self::RcImage(i) => i.hash(state),
|
||||
Self::ImaginateCache(i) => i.hash(state),
|
||||
Self::Color(c) => c.hash(state),
|
||||
Self::Subpaths(s) => s.iter().for_each(|subpath| subpath.hash(state)),
|
||||
Self::RcSubpath(s) => s.hash(state),
|
||||
|
@ -91,7 +91,7 @@ impl Hash for TaggedValue {
|
|||
Self::LuminanceCalculation(l) => l.hash(state),
|
||||
Self::ImaginateSamplingMethod(m) => m.hash(state),
|
||||
Self::ImaginateMaskStartingFill(f) => f.hash(state),
|
||||
Self::ImaginateStatus(s) => s.hash(state),
|
||||
Self::ImaginateController(s) => s.hash(state),
|
||||
Self::LayerPath(p) => p.hash(state),
|
||||
Self::ImageFrame(i) => i.hash(state),
|
||||
Self::VectorData(vector_data) => vector_data.hash(state),
|
||||
|
@ -146,7 +146,7 @@ impl<'a> TaggedValue {
|
|||
TaggedValue::OptionalDVec2(x) => Box::new(x),
|
||||
TaggedValue::DAffine2(x) => Box::new(x),
|
||||
TaggedValue::Image(x) => Box::new(x),
|
||||
TaggedValue::RcImage(x) => Box::new(x),
|
||||
TaggedValue::ImaginateCache(x) => Box::new(x),
|
||||
TaggedValue::ImageFrame(x) => Box::new(x),
|
||||
TaggedValue::Color(x) => Box::new(x),
|
||||
TaggedValue::Subpaths(x) => Box::new(x),
|
||||
|
@ -155,7 +155,7 @@ impl<'a> TaggedValue {
|
|||
TaggedValue::LuminanceCalculation(x) => Box::new(x),
|
||||
TaggedValue::ImaginateSamplingMethod(x) => Box::new(x),
|
||||
TaggedValue::ImaginateMaskStartingFill(x) => Box::new(x),
|
||||
TaggedValue::ImaginateStatus(x) => Box::new(x),
|
||||
TaggedValue::ImaginateController(x) => Box::new(x),
|
||||
TaggedValue::LayerPath(x) => Box::new(x),
|
||||
TaggedValue::VectorData(x) => Box::new(x),
|
||||
TaggedValue::Fill(x) => Box::new(x),
|
||||
|
@ -210,7 +210,7 @@ impl<'a> TaggedValue {
|
|||
TaggedValue::DVec2(_) => concrete!(DVec2),
|
||||
TaggedValue::OptionalDVec2(_) => concrete!(Option<DVec2>),
|
||||
TaggedValue::Image(_) => concrete!(graphene_core::raster::Image<Color>),
|
||||
TaggedValue::RcImage(_) => concrete!(Option<Arc<graphene_core::raster::Image<Color>>>),
|
||||
TaggedValue::ImaginateCache(_) => concrete!(ImaginateCache),
|
||||
TaggedValue::ImageFrame(_) => concrete!(graphene_core::raster::ImageFrame<Color>),
|
||||
TaggedValue::Color(_) => concrete!(graphene_core::raster::Color),
|
||||
TaggedValue::Subpaths(_) => concrete!(Vec<bezier_rs::Subpath<graphene_core::uuid::ManipulatorGroupId>>),
|
||||
|
@ -218,7 +218,7 @@ impl<'a> TaggedValue {
|
|||
TaggedValue::BlendMode(_) => concrete!(BlendMode),
|
||||
TaggedValue::ImaginateSamplingMethod(_) => concrete!(ImaginateSamplingMethod),
|
||||
TaggedValue::ImaginateMaskStartingFill(_) => concrete!(ImaginateMaskStartingFill),
|
||||
TaggedValue::ImaginateStatus(_) => concrete!(ImaginateStatus),
|
||||
TaggedValue::ImaginateController(_) => concrete!(ImaginateController),
|
||||
TaggedValue::LayerPath(_) => concrete!(Option<Vec<u64>>),
|
||||
TaggedValue::DAffine2(_) => concrete!(DAffine2),
|
||||
TaggedValue::LuminanceCalculation(_) => concrete!(LuminanceCalculation),
|
||||
|
@ -263,7 +263,7 @@ impl<'a> TaggedValue {
|
|||
x if x == TypeId::of::<DVec2>() => Ok(TaggedValue::DVec2(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<Option<DVec2>>() => Ok(TaggedValue::OptionalDVec2(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<graphene_core::raster::Image<Color>>() => Ok(TaggedValue::Image(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<Option<Arc<graphene_core::raster::Image<Color>>>>() => Ok(TaggedValue::RcImage(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<ImaginateCache>() => Ok(TaggedValue::ImaginateCache(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<graphene_core::raster::ImageFrame<Color>>() => Ok(TaggedValue::ImageFrame(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<graphene_core::raster::Color>() => Ok(TaggedValue::Color(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<Vec<bezier_rs::Subpath<graphene_core::uuid::ManipulatorGroupId>>>() => Ok(TaggedValue::Subpaths(*downcast(input).unwrap())),
|
||||
|
@ -271,7 +271,7 @@ impl<'a> TaggedValue {
|
|||
x if x == TypeId::of::<BlendMode>() => Ok(TaggedValue::BlendMode(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<ImaginateSamplingMethod>() => Ok(TaggedValue::ImaginateSamplingMethod(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<ImaginateMaskStartingFill>() => Ok(TaggedValue::ImaginateMaskStartingFill(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<ImaginateStatus>() => Ok(TaggedValue::ImaginateStatus(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<ImaginateController>() => Ok(TaggedValue::ImaginateController(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<Option<Vec<u64>>>() => Ok(TaggedValue::LayerPath(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<DAffine2>() => Ok(TaggedValue::DAffine2(*downcast(input).unwrap())),
|
||||
x if x == TypeId::of::<LuminanceCalculation>() => Ok(TaggedValue::LuminanceCalculation(*downcast(input).unwrap())),
|
||||
|
|
|
@ -1,50 +1,155 @@
|
|||
use dyn_any::{DynAny, StaticType};
|
||||
use glam::DVec2;
|
||||
use graphene_core::Color;
|
||||
use std::borrow::Cow;
|
||||
use std::fmt::Debug;
|
||||
use std::sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc, Mutex,
|
||||
};
|
||||
|
||||
#[derive(Default, Debug, Clone, Copy, PartialEq, DynAny, specta::Type)]
|
||||
#[derive(Default, Debug, Clone, DynAny, specta::Type)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub struct ImaginateCache(Arc<Mutex<graphene_core::raster::Image<Color>>>);
|
||||
|
||||
impl ImaginateCache {
|
||||
pub fn into_inner(self) -> Arc<Mutex<graphene_core::raster::Image<Color>>> {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl std::cmp::PartialEq for ImaginateCache {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
Arc::ptr_eq(&self.0, &other.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl core::hash::Hash for ImaginateCache {
|
||||
fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
|
||||
self.0.lock().unwrap().hash(state);
|
||||
}
|
||||
}
|
||||
|
||||
pub trait ImaginateTerminationHandle: Debug + Send + Sync + 'static {
|
||||
fn terminate(&self);
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, specta::Type)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
|
||||
struct InternalImaginateControl {
|
||||
status: Mutex<ImaginateStatus>,
|
||||
trigger_regenerate: AtomicBool,
|
||||
#[serde(skip)]
|
||||
termination_sender: Mutex<Option<Box<dyn ImaginateTerminationHandle>>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Clone, DynAny, specta::Type)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub struct ImaginateController(Arc<InternalImaginateControl>);
|
||||
|
||||
impl ImaginateController {
|
||||
pub fn get_status(&self) -> ImaginateStatus {
|
||||
self.0.status.lock().as_deref().cloned().unwrap_or_default()
|
||||
}
|
||||
|
||||
pub fn set_status(&self, status: ImaginateStatus) {
|
||||
if let Ok(mut lock) = self.0.status.lock() {
|
||||
*lock = status
|
||||
}
|
||||
}
|
||||
|
||||
pub fn take_regenerate_trigger(&self) -> bool {
|
||||
self.0.trigger_regenerate.swap(false, Ordering::SeqCst)
|
||||
}
|
||||
|
||||
pub fn trigger_regenerate(&self) {
|
||||
self.0.trigger_regenerate.store(true, Ordering::SeqCst)
|
||||
}
|
||||
|
||||
pub fn request_termination(&self) {
|
||||
if let Some(handle) = self.0.termination_sender.lock().ok().and_then(|mut lock| lock.take()) {
|
||||
handle.terminate()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_termination_handle<H: ImaginateTerminationHandle>(&self, handle: Box<H>) {
|
||||
if let Ok(mut lock) = self.0.termination_sender.lock() {
|
||||
*lock = Some(handle)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::cmp::PartialEq for ImaginateController {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
Arc::ptr_eq(&self.0, &other.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl core::hash::Hash for ImaginateController {
|
||||
fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
|
||||
core::ptr::hash(Arc::as_ptr(&self.0), state)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Clone, PartialEq, DynAny, specta::Type)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub enum ImaginateStatus {
|
||||
#[default]
|
||||
Idle,
|
||||
Ready,
|
||||
ReadyDone,
|
||||
Beginning,
|
||||
Uploading(f64),
|
||||
Generating,
|
||||
Uploading,
|
||||
Generating(f64),
|
||||
Terminating,
|
||||
Terminated,
|
||||
Failed(String),
|
||||
}
|
||||
|
||||
impl ImaginateStatus {
|
||||
pub fn to_text(&self) -> Cow<'static, str> {
|
||||
match self {
|
||||
Self::Ready => Cow::Borrowed("Ready"),
|
||||
Self::ReadyDone => Cow::Borrowed("Done"),
|
||||
Self::Beginning => Cow::Borrowed("Beginning…"),
|
||||
Self::Uploading => Cow::Borrowed("Downloading Image…"),
|
||||
Self::Generating(percent) => Cow::Owned(format!("Generating {percent:.0}%")),
|
||||
Self::Terminating => Cow::Owned(format!("Terminating…")),
|
||||
Self::Terminated => Cow::Owned(format!("Terminated")),
|
||||
Self::Failed(err) => Cow::Owned(format!("Failed: {err}")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::derived_hash_with_manual_eq)]
|
||||
impl core::hash::Hash for ImaginateStatus {
|
||||
fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
|
||||
core::mem::discriminant(self).hash(state);
|
||||
match self {
|
||||
Self::Idle => 0.hash(state),
|
||||
Self::Beginning => 1.hash(state),
|
||||
Self::Uploading(f) => {
|
||||
2.hash(state);
|
||||
f.to_bits().hash(state);
|
||||
}
|
||||
Self::Generating => 3.hash(state),
|
||||
Self::Terminating => 4.hash(state),
|
||||
Self::Terminated => 5.hash(state),
|
||||
Self::Ready | Self::ReadyDone | Self::Beginning | Self::Uploading | Self::Terminating | Self::Terminated => (),
|
||||
Self::Generating(f) => f.to_bits().hash(state),
|
||||
Self::Failed(err) => err.hash(state),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, specta::Type)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub struct ImaginateBaseImage {
|
||||
pub mime: String,
|
||||
#[cfg_attr(feature = "serde", serde(rename = "imageData"))]
|
||||
pub image_data: Vec<u8>,
|
||||
pub size: DVec2,
|
||||
#[derive(PartialEq, Eq, Clone, Default, Debug)]
|
||||
pub enum ImaginateServerStatus {
|
||||
#[default]
|
||||
Unknown,
|
||||
Checking,
|
||||
Connected,
|
||||
Failed(String),
|
||||
Unavailable,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, specta::Type)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub struct ImaginateMaskImage {
|
||||
pub svg: String,
|
||||
pub size: DVec2,
|
||||
impl ImaginateServerStatus {
|
||||
pub fn to_text(&self) -> Cow<'static, str> {
|
||||
match self {
|
||||
Self::Unknown | Self::Checking => Cow::Borrowed("Checking..."),
|
||||
Self::Connected => Cow::Borrowed("Connected"),
|
||||
Self::Failed(err) => Cow::Owned(err.clone()),
|
||||
Self::Unavailable => Cow::Borrowed("Unavailable"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
|
||||
|
@ -180,24 +285,26 @@ impl std::fmt::Display for ImaginateSamplingMethod {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, specta::Type)]
|
||||
#[derive(Clone, Debug, PartialEq, Hash, specta::Type)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub struct ImaginateGenerationParameters {
|
||||
pub seed: u64,
|
||||
pub samples: u32,
|
||||
/// Use `ImaginateSamplingMethod::api_value()` to generate this string
|
||||
#[cfg_attr(feature = "serde", serde(rename = "samplingMethod"))]
|
||||
pub sampling_method: String,
|
||||
#[cfg_attr(feature = "serde", serde(rename = "denoisingStrength"))]
|
||||
pub image_creativity: Option<f64>,
|
||||
#[cfg_attr(feature = "serde", serde(rename = "cfgScale"))]
|
||||
pub text_guidance: f64,
|
||||
#[cfg_attr(feature = "serde", serde(rename = "prompt"))]
|
||||
pub text_prompt: String,
|
||||
#[cfg_attr(feature = "serde", serde(rename = "negativePrompt"))]
|
||||
pub negative_prompt: String,
|
||||
pub resolution: (u32, u32),
|
||||
#[cfg_attr(feature = "serde", serde(rename = "restoreFaces"))]
|
||||
pub restore_faces: bool,
|
||||
pub tiling: bool,
|
||||
pub struct ImaginatePreferences {
|
||||
pub host_name: String,
|
||||
}
|
||||
|
||||
impl graphene_core::application_io::GetImaginatePreferences for ImaginatePreferences {
|
||||
fn get_host_name(&self) -> &str {
|
||||
&self.host_name
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ImaginatePreferences {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
host_name: "http://localhost:7860/".into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl dyn_any::StaticType for ImaginatePreferences {
|
||||
type Static = ImaginatePreferences;
|
||||
}
|
||||
|
|
|
@ -9,12 +9,18 @@ license = "MIT OR Apache-2.0"
|
|||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[features]
|
||||
default = ["wasm"]
|
||||
gpu = ["graphene-core/gpu", "gpu-compiler-bin-wrapper", "compilation-client", "gpu-executor"]
|
||||
default = ["wasm", "imaginate"]
|
||||
gpu = [
|
||||
"graphene-core/gpu",
|
||||
"gpu-compiler-bin-wrapper",
|
||||
"compilation-client",
|
||||
"gpu-executor",
|
||||
]
|
||||
vulkan = ["gpu", "vulkan-executor"]
|
||||
wgpu = ["gpu", "wgpu-executor"]
|
||||
quantization = ["autoquant"]
|
||||
wasm = ["wasm-bindgen", "web-sys", "js-sys"]
|
||||
imaginate = ["image/png", "base64", "js-sys", "web-sys", "wasm-bindgen-futures"]
|
||||
|
||||
|
||||
[dependencies]
|
||||
|
@ -37,6 +43,7 @@ compilation-client = { path = "../compilation-client", optional = true }
|
|||
bytemuck = { version = "1.8" }
|
||||
tempfile = "3"
|
||||
image = { version = "*", default-features = false }
|
||||
base64 = { version = "0.21", optional = true }
|
||||
dyn-clone = "1.0"
|
||||
|
||||
log = "0.4"
|
||||
|
@ -48,12 +55,13 @@ glam = { version = "0.22", features = ["serde"] }
|
|||
node-macro = { path = "../node-macro" }
|
||||
xxhash-rust = { workspace = true }
|
||||
serde_json = "1.0.96"
|
||||
reqwest = { version = "0.11.17", features = ["rustls", "rustls-tls"] }
|
||||
reqwest = { version = "0.11.18", features = ["rustls", "rustls-tls", "json"] }
|
||||
futures = "0.3.28"
|
||||
wasm-bindgen = { version = "0.2.84", optional = true }
|
||||
js-sys = { version = "0.3.55", optional = true }
|
||||
js-sys = { version = "0.3.63", optional = true }
|
||||
wgpu-types = "0.16.0"
|
||||
wgpu = "0.16.1"
|
||||
wasm-bindgen-futures = { version = "0.4.36", optional = true }
|
||||
|
||||
[dependencies.serde]
|
||||
version = "1.0"
|
||||
|
@ -62,7 +70,7 @@ features = ["derive"]
|
|||
|
||||
|
||||
[dependencies.web-sys]
|
||||
version = "0.3.4"
|
||||
version = "0.3.63"
|
||||
optional = true
|
||||
features = [
|
||||
"Window",
|
||||
|
|
|
@ -76,6 +76,7 @@ impl<_I, _O, S0> DynAnyRefNode<_I, _O, S0> {
|
|||
Self { node, _i: core::marker::PhantomData }
|
||||
}
|
||||
}
|
||||
|
||||
pub struct DynAnyInRefNode<I, O, Node> {
|
||||
node: Node,
|
||||
_i: PhantomData<(I, O)>,
|
||||
|
@ -115,6 +116,10 @@ where
|
|||
fn reset(&self) {
|
||||
self.node.reset();
|
||||
}
|
||||
|
||||
fn serialize(&self) -> Option<std::sync::Arc<dyn core::any::Any>> {
|
||||
self.node.serialize()
|
||||
}
|
||||
}
|
||||
|
||||
impl<N> FutureWrapperNode<N> {
|
||||
|
|
|
@ -1,5 +1,3 @@
|
|||
use std::future::Future;
|
||||
|
||||
use crate::Node;
|
||||
|
||||
pub struct GetNode;
|
||||
|
@ -17,16 +15,3 @@ pub struct PostNode<Body> {
|
|||
async fn post_node(url: String, body: String) -> reqwest::Response {
|
||||
reqwest::Client::new().post(url).body(body).send().await.unwrap()
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub struct EvalSyncNode {}
|
||||
|
||||
#[node_macro::node_fn(EvalSyncNode)]
|
||||
fn eval_sync<F: Future + 'input>(future: F) -> F::Output {
|
||||
let future = futures::future::maybe_done(future);
|
||||
futures::pin_mut!(future);
|
||||
match future.as_mut().take_output() {
|
||||
Some(value) => value,
|
||||
_ => panic!("Node construction future returned pending"),
|
||||
}
|
||||
}
|
||||
|
|
517
node-graph/gstd/src/imaginate.rs
Normal file
517
node-graph/gstd/src/imaginate.rs
Normal file
|
@ -0,0 +1,517 @@
|
|||
use crate::wasm_application_io::WasmEditorApi;
|
||||
use core::any::TypeId;
|
||||
use core::future::Future;
|
||||
use futures::{future::Either, TryFutureExt};
|
||||
use glam::DVec2;
|
||||
use graph_craft::imaginate_input::{ImaginateController, ImaginateMaskStartingFill, ImaginatePreferences, ImaginateSamplingMethod, ImaginateServerStatus, ImaginateStatus, ImaginateTerminationHandle};
|
||||
use graphene_core::application_io::NodeGraphUpdateMessage;
|
||||
use graphene_core::raster::{Color, Image, Luma, Pixel};
|
||||
use image::{DynamicImage, ImageBuffer, ImageOutputFormat};
|
||||
use reqwest::Url;
|
||||
|
||||
const PROGRESS_EVERY_N_STEPS: u32 = 5;
|
||||
const SDAPI_TEXT_TO_IMAGE: &str = "sdapi/v1/txt2img";
|
||||
const SDAPI_IMAGE_TO_IMAGE: &str = "sdapi/v1/img2img";
|
||||
const SDAPI_PROGRESS: &str = "sdapi/v1/progress?skip_current_image=true";
|
||||
const SDAPI_TERMINATE: &str = "sdapi/v1/interrupt";
|
||||
|
||||
fn new_client() -> Result<reqwest::Client, Error> {
|
||||
reqwest::ClientBuilder::new().build().map_err(Error::ClientBuild)
|
||||
}
|
||||
|
||||
fn parse_url(url: &str) -> Result<Url, Error> {
|
||||
url.try_into().map_err(|err| Error::UrlParse { text: url.into(), err })
|
||||
}
|
||||
|
||||
fn join_url(base_url: &Url, path: &str) -> Result<Url, Error> {
|
||||
base_url.join(path).map_err(|err| Error::UrlParse { text: base_url.to_string(), err })
|
||||
}
|
||||
|
||||
fn new_get_request<U: reqwest::IntoUrl>(client: &reqwest::Client, url: U) -> Result<reqwest::Request, Error> {
|
||||
client.get(url).header("Accept", "*/*").build().map_err(Error::RequestBuild)
|
||||
}
|
||||
|
||||
pub struct ImaginatePersistentData {
|
||||
pending_server_check: Option<futures::channel::oneshot::Receiver<reqwest::Result<reqwest::Response>>>,
|
||||
host_name: Url,
|
||||
client: Option<reqwest::Client>,
|
||||
server_status: ImaginateServerStatus,
|
||||
}
|
||||
|
||||
impl core::fmt::Debug for ImaginatePersistentData {
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
|
||||
f.debug_struct(core::any::type_name::<Self>())
|
||||
.field("pending_server_check", &self.pending_server_check.is_some())
|
||||
.field("host_name", &self.host_name)
|
||||
.field("status", &self.server_status)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ImaginatePersistentData {
|
||||
fn default() -> Self {
|
||||
let mut status = ImaginateServerStatus::default();
|
||||
let client = new_client().map_err(|err| status = ImaginateServerStatus::Failed(err.to_string())).ok();
|
||||
let ImaginatePreferences { host_name } = Default::default();
|
||||
Self {
|
||||
pending_server_check: None,
|
||||
host_name: parse_url(&host_name).unwrap(),
|
||||
client,
|
||||
server_status: status,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ImaginatePersistentData {
|
||||
pub fn set_host_name(&mut self, name: &str) {
|
||||
match parse_url(name) {
|
||||
Ok(url) => self.host_name = url,
|
||||
Err(err) => self.server_status = ImaginateServerStatus::Failed(err.to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
fn initiate_server_check_maybe_fail(&mut self) -> Result<Option<core::pin::Pin<Box<dyn Future<Output = ()> + 'static>>>, Error> {
|
||||
use futures::future::FutureExt;
|
||||
let Some(client) = &self.client else { return Ok(None); };
|
||||
if self.pending_server_check.is_some() {
|
||||
return Ok(None);
|
||||
}
|
||||
self.server_status = ImaginateServerStatus::Checking;
|
||||
let url = join_url(&self.host_name, SDAPI_PROGRESS)?;
|
||||
let request = new_get_request(client, url)?;
|
||||
let (send, recv) = futures::channel::oneshot::channel();
|
||||
let response_future = client.execute(request).map(move |r| {
|
||||
let _ = send.send(r);
|
||||
});
|
||||
self.pending_server_check = Some(recv);
|
||||
Ok(Some(Box::pin(response_future)))
|
||||
}
|
||||
|
||||
pub fn initiate_server_check(&mut self) -> Option<core::pin::Pin<Box<dyn Future<Output = ()> + 'static>>> {
|
||||
match self.initiate_server_check_maybe_fail() {
|
||||
Ok(f) => f,
|
||||
Err(err) => {
|
||||
self.server_status = ImaginateServerStatus::Failed(err.to_string());
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn poll_server_check(&mut self) {
|
||||
if let Some(mut check) = self.pending_server_check.take() {
|
||||
self.server_status = match check.try_recv().map(|r| r.map(|r| r.and_then(reqwest::Response::error_for_status))) {
|
||||
Ok(Some(Ok(_response))) => ImaginateServerStatus::Connected,
|
||||
Ok(Some(Err(_))) | Err(_) => ImaginateServerStatus::Unavailable,
|
||||
Ok(None) => {
|
||||
self.pending_server_check = Some(check);
|
||||
ImaginateServerStatus::Checking
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn server_status(&self) -> &ImaginateServerStatus {
|
||||
&self.server_status
|
||||
}
|
||||
|
||||
pub fn is_checking(&self) -> bool {
|
||||
matches!(self.server_status, ImaginateServerStatus::Checking)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct ImaginateFutureAbortHandle(futures::future::AbortHandle);
|
||||
|
||||
impl ImaginateTerminationHandle for ImaginateFutureAbortHandle {
|
||||
fn terminate(&self) {
|
||||
self.0.abort()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum Error {
|
||||
UrlParse { text: String, err: <&'static str as TryInto<Url>>::Error },
|
||||
ClientBuild(reqwest::Error),
|
||||
RequestBuild(reqwest::Error),
|
||||
Request(reqwest::Error),
|
||||
ResponseFormat(reqwest::Error),
|
||||
NoImage,
|
||||
Base64Decode(base64::DecodeError),
|
||||
ImageDecode(image::error::ImageError),
|
||||
ImageEncode(image::error::ImageError),
|
||||
UnsupportedPixelType(&'static str),
|
||||
InconsistentImageSize,
|
||||
Terminated,
|
||||
TerminationFailed(reqwest::Error),
|
||||
}
|
||||
|
||||
impl core::fmt::Display for Error {
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
|
||||
match self {
|
||||
Self::UrlParse { text, err } => write!(f, "invalid url '{text}' ({err})"),
|
||||
Self::ClientBuild(err) => write!(f, "failed to create a reqwest client ({err})"),
|
||||
Self::RequestBuild(err) => write!(f, "failed to create a reqwest request ({err})"),
|
||||
Self::Request(err) => write!(f, "request failed ({err})"),
|
||||
Self::ResponseFormat(err) => write!(f, "got an invalid API response ({err})"),
|
||||
Self::NoImage => write!(f, "got an empty API response"),
|
||||
Self::Base64Decode(err) => write!(f, "failed to decode base64 encoded image ({err})"),
|
||||
Self::ImageDecode(err) => write!(f, "failed to decode png image ({err})"),
|
||||
Self::ImageEncode(err) => write!(f, "failed to encode png image ({err})"),
|
||||
Self::UnsupportedPixelType(ty) => write!(f, "pixel type `{ty}` not supported for imaginate images"),
|
||||
Self::InconsistentImageSize => write!(f, "image width and height do not match the image byte size"),
|
||||
Self::Terminated => write!(f, "imaginate request was terminated by the user"),
|
||||
Self::TerminationFailed(err) => write!(f, "termination failed ({err})"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for Error {}
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Deserialize))]
|
||||
struct ImageResponse {
|
||||
images: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Deserialize))]
|
||||
struct ProgressResponse {
|
||||
progress: f64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
|
||||
struct ImaginateTextToImageRequestOverrideSettings {
|
||||
show_progress_every_n_steps: u32,
|
||||
}
|
||||
|
||||
impl Default for ImaginateTextToImageRequestOverrideSettings {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
show_progress_every_n_steps: PROGRESS_EVERY_N_STEPS,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
|
||||
struct ImaginateImageToImageRequestOverrideSettings {
|
||||
show_progress_every_n_steps: u32,
|
||||
img2img_fix_steps: bool,
|
||||
}
|
||||
|
||||
impl Default for ImaginateImageToImageRequestOverrideSettings {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
show_progress_every_n_steps: PROGRESS_EVERY_N_STEPS,
|
||||
img2img_fix_steps: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
|
||||
struct ImaginateTextToImageRequest<'a> {
|
||||
#[serde(flatten)]
|
||||
common: ImaginateCommonImageRequest<'a>,
|
||||
override_settings: ImaginateTextToImageRequestOverrideSettings,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
|
||||
struct ImaginateMask {
|
||||
mask: String,
|
||||
mask_blur: String,
|
||||
inpainting_fill: u32,
|
||||
inpaint_full_res: bool,
|
||||
inpainting_mask_invert: u32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
|
||||
struct ImaginateImageToImageRequest<'a> {
|
||||
#[serde(flatten)]
|
||||
common: ImaginateCommonImageRequest<'a>,
|
||||
override_settings: ImaginateImageToImageRequestOverrideSettings,
|
||||
|
||||
init_images: Vec<String>,
|
||||
denoising_strength: f64,
|
||||
#[serde(flatten)]
|
||||
mask: Option<ImaginateMask>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
|
||||
struct ImaginateCommonImageRequest<'a> {
|
||||
prompt: String,
|
||||
seed: f64,
|
||||
steps: u32,
|
||||
cfg_scale: f64,
|
||||
width: f64,
|
||||
height: f64,
|
||||
restore_faces: bool,
|
||||
tiling: bool,
|
||||
negative_prompt: String,
|
||||
sampler_index: &'a str,
|
||||
}
|
||||
|
||||
#[cfg(feature = "imaginate")]
|
||||
pub async fn imaginate<'a, P: Pixel>(
|
||||
image: Image<P>,
|
||||
editor_api: impl Future<Output = WasmEditorApi<'a>>,
|
||||
controller: ImaginateController,
|
||||
seed: impl Future<Output = f64>,
|
||||
res: impl Future<Output = Option<DVec2>>,
|
||||
samples: impl Future<Output = u32>,
|
||||
sampling_method: impl Future<Output = ImaginateSamplingMethod>,
|
||||
prompt_guidance: impl Future<Output = f64>,
|
||||
prompt: impl Future<Output = String>,
|
||||
negative_prompt: impl Future<Output = String>,
|
||||
adapt_input_image: impl Future<Output = bool>,
|
||||
image_creativity: impl Future<Output = f64>,
|
||||
masking_layer: impl Future<Output = Option<Vec<u64>>>,
|
||||
inpaint: impl Future<Output = bool>,
|
||||
mask_blur: impl Future<Output = f64>,
|
||||
mask_starting_fill: impl Future<Output = ImaginateMaskStartingFill>,
|
||||
improve_faces: impl Future<Output = bool>,
|
||||
tiling: impl Future<Output = bool>,
|
||||
) -> Image<P> {
|
||||
let WasmEditorApi {
|
||||
node_graph_message_sender,
|
||||
imaginate_preferences,
|
||||
..
|
||||
} = editor_api.await;
|
||||
let set_progress = |progress: ImaginateStatus| {
|
||||
controller.set_status(progress);
|
||||
node_graph_message_sender.send(NodeGraphUpdateMessage::ImaginateStatusUpdate);
|
||||
};
|
||||
let host_name = imaginate_preferences.get_host_name();
|
||||
imaginate_maybe_fail(
|
||||
image,
|
||||
host_name,
|
||||
set_progress,
|
||||
&controller,
|
||||
seed,
|
||||
res,
|
||||
samples,
|
||||
sampling_method,
|
||||
prompt_guidance,
|
||||
prompt,
|
||||
negative_prompt,
|
||||
adapt_input_image,
|
||||
image_creativity,
|
||||
masking_layer,
|
||||
inpaint,
|
||||
mask_blur,
|
||||
mask_starting_fill,
|
||||
improve_faces,
|
||||
tiling,
|
||||
)
|
||||
.await
|
||||
.unwrap_or_else(|err| {
|
||||
match err {
|
||||
Error::Terminated => {
|
||||
set_progress(ImaginateStatus::Terminated);
|
||||
}
|
||||
err => {
|
||||
error!("{err}");
|
||||
set_progress(ImaginateStatus::Failed(err.to_string()));
|
||||
}
|
||||
};
|
||||
Image::empty()
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(feature = "imaginate")]
|
||||
async fn imaginate_maybe_fail<'a, P: Pixel, F: Fn(ImaginateStatus)>(
|
||||
image: Image<P>,
|
||||
host_name: &str,
|
||||
set_progress: F,
|
||||
controller: &ImaginateController,
|
||||
seed: impl Future<Output = f64>,
|
||||
res: impl Future<Output = Option<DVec2>>,
|
||||
samples: impl Future<Output = u32>,
|
||||
sampling_method: impl Future<Output = ImaginateSamplingMethod>,
|
||||
prompt_guidance: impl Future<Output = f64>,
|
||||
prompt: impl Future<Output = String>,
|
||||
negative_prompt: impl Future<Output = String>,
|
||||
adapt_input_image: impl Future<Output = bool>,
|
||||
image_creativity: impl Future<Output = f64>,
|
||||
_masking_layer: impl Future<Output = Option<Vec<u64>>>,
|
||||
_inpaint: impl Future<Output = bool>,
|
||||
_mask_blur: impl Future<Output = f64>,
|
||||
_mask_starting_fill: impl Future<Output = ImaginateMaskStartingFill>,
|
||||
improve_faces: impl Future<Output = bool>,
|
||||
tiling: impl Future<Output = bool>,
|
||||
) -> Result<Image<P>, Error> {
|
||||
set_progress(ImaginateStatus::Beginning);
|
||||
|
||||
let base_url: Url = parse_url(host_name)?;
|
||||
|
||||
let client = new_client()?;
|
||||
|
||||
let sampler_index = sampling_method.await;
|
||||
let sampler_index = sampler_index.api_value();
|
||||
|
||||
let res = res.await.unwrap_or_else(|| {
|
||||
let (width, height) = pick_safe_imaginate_resolution((image.width as _, image.height as _));
|
||||
DVec2::new(width as _, height as _)
|
||||
});
|
||||
let common_request_data = ImaginateCommonImageRequest {
|
||||
prompt: prompt.await,
|
||||
seed: seed.await,
|
||||
steps: samples.await,
|
||||
cfg_scale: prompt_guidance.await,
|
||||
width: res.x,
|
||||
height: res.y,
|
||||
restore_faces: improve_faces.await,
|
||||
tiling: tiling.await,
|
||||
negative_prompt: negative_prompt.await,
|
||||
sampler_index,
|
||||
};
|
||||
let request_builder = if adapt_input_image.await {
|
||||
let base64_data = image_to_base64(image)?;
|
||||
let request_data = ImaginateImageToImageRequest {
|
||||
common: common_request_data,
|
||||
override_settings: Default::default(),
|
||||
|
||||
init_images: vec![base64_data],
|
||||
denoising_strength: image_creativity.await * 0.01,
|
||||
mask: None,
|
||||
};
|
||||
let url = join_url(&base_url, SDAPI_IMAGE_TO_IMAGE)?;
|
||||
client.post(url).json(&request_data)
|
||||
} else {
|
||||
let request_data = ImaginateTextToImageRequest {
|
||||
common: common_request_data,
|
||||
override_settings: Default::default(),
|
||||
};
|
||||
let url = join_url(&base_url, SDAPI_TEXT_TO_IMAGE)?;
|
||||
client.post(url).json(&request_data)
|
||||
};
|
||||
|
||||
let request = request_builder.header("Accept", "*/*").build().map_err(Error::RequestBuild)?;
|
||||
|
||||
let (response_future, abort_handle) = futures::future::abortable(client.execute(request));
|
||||
controller.set_termination_handle(Box::new(ImaginateFutureAbortHandle(abort_handle)));
|
||||
|
||||
let progress_url = join_url(&base_url, SDAPI_PROGRESS)?;
|
||||
|
||||
futures::pin_mut!(response_future);
|
||||
|
||||
let response = loop {
|
||||
let progress_request = new_get_request(&client, progress_url.clone())?;
|
||||
let progress_response_future = client.execute(progress_request).and_then(|response| response.json());
|
||||
|
||||
futures::pin_mut!(progress_response_future);
|
||||
|
||||
response_future = match futures::future::select(response_future, progress_response_future).await {
|
||||
Either::Left((response, _)) => break response,
|
||||
Either::Right((progress, response_future)) => {
|
||||
if let Ok(ProgressResponse { progress }) = progress {
|
||||
set_progress(ImaginateStatus::Generating(progress * 100.));
|
||||
}
|
||||
response_future
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
let response = match response {
|
||||
Ok(response) => response.and_then(reqwest::Response::error_for_status).map_err(Error::Request)?,
|
||||
Err(_aborted) => {
|
||||
set_progress(ImaginateStatus::Terminating);
|
||||
let url = join_url(&base_url, SDAPI_TERMINATE)?;
|
||||
let request = client.post(url).build().map_err(Error::RequestBuild)?;
|
||||
// The user probably doesn't really care if the server side was really aborted or if there was an network error.
|
||||
// So we fool them that the request was terminated if the termination request in reality failed.
|
||||
let _ = client.execute(request).await.and_then(reqwest::Response::error_for_status).map_err(Error::TerminationFailed)?;
|
||||
return Err(Error::Terminated);
|
||||
}
|
||||
};
|
||||
|
||||
set_progress(ImaginateStatus::Uploading);
|
||||
|
||||
let ImageResponse { images } = response.json().await.map_err(Error::ResponseFormat)?;
|
||||
|
||||
let result = images.into_iter().next().ok_or(Error::NoImage).and_then(base64_to_image)?;
|
||||
|
||||
set_progress(ImaginateStatus::ReadyDone);
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn image_to_base64<P: Pixel>(image: Image<P>) -> Result<String, Error> {
|
||||
use base64::prelude::*;
|
||||
|
||||
let Image { width, height, data } = image;
|
||||
|
||||
fn cast_with_f32<S: Pixel, D: image::Pixel<Subpixel = f32>>(data: Vec<S>, width: u32, height: u32) -> Result<DynamicImage, Error>
|
||||
where
|
||||
DynamicImage: From<ImageBuffer<D, Vec<f32>>>,
|
||||
{
|
||||
ImageBuffer::<D, Vec<f32>>::from_raw(width, height, bytemuck::cast_vec(data))
|
||||
.ok_or(Error::InconsistentImageSize)
|
||||
.map(Into::into)
|
||||
}
|
||||
|
||||
let image: DynamicImage = match TypeId::of::<P>() {
|
||||
id if id == TypeId::of::<Color>() => cast_with_f32::<_, image::Rgba<f32>>(data, width, height)?
|
||||
// we need to do this cast, because png does not support rgba32f
|
||||
.to_rgba16().into(),
|
||||
id if id == TypeId::of::<Luma>() => cast_with_f32::<_, image::Luma<f32>>(data, width, height)?
|
||||
// we need to do this cast, because png does not support luma32f
|
||||
.to_luma16().into(),
|
||||
_ => return Err(Error::UnsupportedPixelType(core::any::type_name::<P>())),
|
||||
};
|
||||
|
||||
let mut png_data = std::io::Cursor::new(vec![]);
|
||||
image.write_to(&mut png_data, ImageOutputFormat::Png).map_err(Error::ImageEncode)?;
|
||||
Ok(BASE64_STANDARD.encode(png_data.into_inner()))
|
||||
}
|
||||
|
||||
fn base64_to_image<D: AsRef<[u8]>, P: Pixel>(base64_data: D) -> Result<Image<P>, Error> {
|
||||
use base64::prelude::*;
|
||||
|
||||
let png_data = BASE64_STANDARD.decode(base64_data).map_err(Error::Base64Decode)?;
|
||||
let dyn_image = image::load_from_memory_with_format(&png_data, image::ImageFormat::Png).map_err(Error::ImageDecode)?;
|
||||
let (width, height) = (dyn_image.width(), dyn_image.height());
|
||||
|
||||
let result_data: Vec<P> = match TypeId::of::<P>() {
|
||||
id if id == TypeId::of::<Color>() => bytemuck::cast_vec(dyn_image.into_rgba32f().into_raw()),
|
||||
id if id == TypeId::of::<Luma>() => bytemuck::cast_vec(dyn_image.to_luma32f().into_raw()),
|
||||
_ => return Err(Error::UnsupportedPixelType(core::any::type_name::<P>())),
|
||||
};
|
||||
|
||||
Ok(Image { data: result_data, width, height })
|
||||
}
|
||||
|
||||
pub fn pick_safe_imaginate_resolution((width, height): (f64, f64)) -> (u64, u64) {
|
||||
const MAX_RESOLUTION: u64 = 1000 * 1000;
|
||||
|
||||
// this is the maximum width/height that can be obtained
|
||||
const MAX_DIMENSION: u64 = (MAX_RESOLUTION / 64) & !63;
|
||||
|
||||
// round the resolution to the nearest multiple of 64
|
||||
let [width, height] = [width, height].map(|c| (c.round().clamp(0., MAX_DIMENSION as _) as u64 + 32).max(64) & !63);
|
||||
let resolution = width * height;
|
||||
|
||||
if resolution > MAX_RESOLUTION {
|
||||
// scale down the image, so it is smaller than MAX_RESOLUTION
|
||||
let scale = (MAX_RESOLUTION as f64 / resolution as f64).sqrt();
|
||||
let [width, height] = [width, height].map(|c| c as f64 * scale);
|
||||
|
||||
if width < 64.0 {
|
||||
// the image is extremely wide
|
||||
(64, MAX_DIMENSION)
|
||||
} else if height < 64.0 {
|
||||
// the image is extremely high
|
||||
(MAX_DIMENSION, 64)
|
||||
} else {
|
||||
// round down to a multiple of 64, so that the resolution still is smaller than MAX_RESOLUTION
|
||||
let [width, height] = [width, height].map(|c| c as u64 & !63);
|
||||
(width, height)
|
||||
}
|
||||
} else {
|
||||
(width, height)
|
||||
}
|
||||
}
|
|
@ -25,3 +25,5 @@ pub mod brush;
|
|||
|
||||
#[cfg(feature = "wasm")]
|
||||
pub mod wasm_application_io;
|
||||
|
||||
pub mod imaginate;
|
||||
|
|
|
@ -1,8 +1,11 @@
|
|||
use dyn_any::{DynAny, StaticType};
|
||||
use glam::{DAffine2, DVec2};
|
||||
use graph_craft::imaginate_input::{ImaginateController, ImaginateMaskStartingFill, ImaginateSamplingMethod};
|
||||
use graph_craft::proto::DynFuture;
|
||||
use graphene_core::raster::{Alpha, BlendMode, BlendNode, Image, ImageFrame, Linear, LinearChannel, Luminance, Pixel, RGBMut, Raster, RasterMut, RedGreenBlue, Sample};
|
||||
use graphene_core::transform::Transform;
|
||||
|
||||
use crate::wasm_application_io::WasmEditorApi;
|
||||
use graphene_core::raster::bbox::{AxisAlignedBbox, Bbox};
|
||||
use graphene_core::value::CopiedNode;
|
||||
use graphene_core::{Color, Node};
|
||||
|
@ -414,19 +417,74 @@ fn empty_image<_P: Pixel>(transform: DAffine2, color: _P) -> ImageFrame<_P> {
|
|||
ImageFrame { image, transform }
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct ImaginateNode<P, E> {
|
||||
cached: E,
|
||||
_p: PhantomData<P>,
|
||||
macro_rules! generate_imaginate_node {
|
||||
($($val:ident: $t:ident: $o:ty,)*) => {
|
||||
pub struct ImaginateNode<P: Pixel, E, C, $($t,)*> {
|
||||
editor_api: E,
|
||||
controller: C,
|
||||
$($val: $t,)*
|
||||
cache: std::sync::Arc<std::sync::Mutex<Image<P>>>,
|
||||
}
|
||||
|
||||
impl<'e, P: Pixel, E, C, $($t,)*> ImaginateNode<P, E, C, $($t,)*>
|
||||
where $($t: for<'any_input> Node<'any_input, (), Output = DynFuture<'any_input, $o>>,)*
|
||||
E: for<'any_input> Node<'any_input, (), Output = DynFuture<'any_input, WasmEditorApi<'e>>>,
|
||||
C: for<'any_input> Node<'any_input, (), Output = DynFuture<'any_input, ImaginateController>>,
|
||||
{
|
||||
pub fn new(editor_api: E, controller: C, $($val: $t,)* cache: std::sync::Arc<std::sync::Mutex<Image<P>>>) -> Self {
|
||||
Self { editor_api, controller, $($val,)* cache }
|
||||
}
|
||||
}
|
||||
|
||||
impl<'i, 'e: 'i, P: Pixel + 'i, E: 'i, C: 'i, $($t: 'i,)*> Node<'i, ImageFrame<P>> for ImaginateNode<P, E, C, $($t,)*>
|
||||
where $($t: for<'any_input> Node<'any_input, (), Output = DynFuture<'any_input, $o>>,)*
|
||||
E: for<'any_input> Node<'any_input, (), Output = DynFuture<'any_input, WasmEditorApi<'e>>>,
|
||||
C: for<'any_input> Node<'any_input, (), Output = DynFuture<'any_input, ImaginateController>>,
|
||||
{
|
||||
type Output = DynFuture<'i, ImageFrame<P>>;
|
||||
|
||||
fn eval(&'i self, frame: ImageFrame<P>) -> Self::Output {
|
||||
let controller = self.controller.eval(());
|
||||
$(let $val = self.$val.eval(());)*
|
||||
Box::pin(async move {
|
||||
let controller: std::pin::Pin<Box<dyn std::future::Future<Output = ImaginateController>>> = controller;
|
||||
let controller: ImaginateController = controller.await;
|
||||
if controller.take_regenerate_trigger() {
|
||||
let editor_api = self.editor_api.eval(());
|
||||
let image = super::imaginate::imaginate(frame.image, editor_api, controller, $($val,)*).await;
|
||||
self.cache.lock().unwrap().clone_from(&image);
|
||||
return ImageFrame {
|
||||
image,
|
||||
..frame
|
||||
}
|
||||
}
|
||||
let image = self.cache.lock().unwrap().clone();
|
||||
ImageFrame {
|
||||
image,
|
||||
..frame
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[node_macro::node_fn(ImaginateNode<_P>)]
|
||||
fn imaginate<_P: Pixel>(image_frame: ImageFrame<_P>, cached: Option<std::sync::Arc<graphene_core::raster::Image<_P>>>) -> ImageFrame<_P> {
|
||||
let cached_image = cached.map(|mut x| std::sync::Arc::make_mut(&mut x).clone()).unwrap_or(image_frame.image);
|
||||
ImageFrame {
|
||||
image: cached_image,
|
||||
transform: image_frame.transform,
|
||||
}
|
||||
generate_imaginate_node! {
|
||||
seed: Seed: f64,
|
||||
res: Res: Option<DVec2>,
|
||||
samples: Samples: u32,
|
||||
sampling_method: SamplingMethod: ImaginateSamplingMethod,
|
||||
prompt_guidance: PromptGuidance: f64,
|
||||
prompt: Prompt: String,
|
||||
negative_prompt: NegativePrompt: String,
|
||||
adapt_input_image: AdaptInputImage: bool,
|
||||
image_creativity: ImageCreativity: f64,
|
||||
masking_layer: MaskingLayer: Option<Vec<u64>>,
|
||||
inpaint: Inpaint: bool,
|
||||
mask_blur: MaskBlur: f64,
|
||||
mask_starting_fill: MaskStartingFill: ImaginateMaskStartingFill,
|
||||
improve_faces: ImproveFaces: bool,
|
||||
tiling: Tiling: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
|
|
|
@ -3,7 +3,6 @@ pub mod node_registry;
|
|||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
use graph_craft::document::value::TaggedValue;
|
||||
use graphene_core::*;
|
||||
use std::borrow::Cow;
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
use graph_craft::imaginate_input::{ImaginateCache, ImaginateController, ImaginateMaskStartingFill, ImaginateSamplingMethod};
|
||||
use graph_craft::proto::{NodeConstructor, TypeErasedBox};
|
||||
use graphene_core::ops::IdNode;
|
||||
use graphene_core::quantization::QuantizationChannels;
|
||||
|
@ -444,18 +445,59 @@ fn node_registry() -> HashMap<NodeIdentifier, HashMap<NodeIOTypes, NodeConstruct
|
|||
params: [WasmSurfaceHandleFrame]
|
||||
),
|
||||
async_node!(graphene_core::memo::EndLetNode<_>, input: WasmEditorApi, output: SurfaceFrame, params: [SurfaceFrame]),
|
||||
vec![(
|
||||
NodeIdentifier::new("graphene_core::memo::RefNode<_, _>"),
|
||||
|args| {
|
||||
Box::pin(async move {
|
||||
let node: DowncastBothNode<Option<WasmEditorApi>, WasmEditorApi> = graphene_std::any::DowncastBothNode::new(args[0].clone());
|
||||
let node = <graphene_core::memo::RefNode<_, _>>::new(node);
|
||||
let any: DynAnyNode<(), _, _> = graphene_std::any::DynAnyNode::new(graphene_core::value::ValueNode::new(node));
|
||||
any.into_type_erased()
|
||||
})
|
||||
},
|
||||
NodeIOTypes::new(concrete!(()), concrete!(WasmEditorApi), vec![fn_type!(Option<WasmEditorApi>, WasmEditorApi)]),
|
||||
)],
|
||||
vec![
|
||||
(
|
||||
NodeIdentifier::new("graphene_core::memo::RefNode<_, _>"),
|
||||
|args| {
|
||||
Box::pin(async move {
|
||||
let node: DowncastBothNode<Option<WasmEditorApi>, WasmEditorApi> = graphene_std::any::DowncastBothNode::new(args[0].clone());
|
||||
let node = <graphene_core::memo::RefNode<_, _>>::new(node);
|
||||
let any: DynAnyNode<(), _, _> = graphene_std::any::DynAnyNode::new(graphene_core::value::ValueNode::new(node));
|
||||
any.into_type_erased()
|
||||
})
|
||||
},
|
||||
NodeIOTypes::new(concrete!(()), concrete!(WasmEditorApi), vec![fn_type!(Option<WasmEditorApi>, WasmEditorApi)]),
|
||||
),
|
||||
(
|
||||
NodeIdentifier::new("graphene_std::raster::ImaginateNode<_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _>"),
|
||||
|args: Vec<Arc<graph_craft::proto::NodeContainer>>| {
|
||||
Box::pin(async move {
|
||||
use graphene_std::raster::ImaginateNode;
|
||||
let cache: ImaginateCache = graphene_std::any::input_node(args.last().unwrap().clone()).eval(()).await;
|
||||
macro_rules! instanciate_imaginate_node {
|
||||
($($i:expr,)*) => { ImaginateNode::new($(graphene_std::any::input_node(args[$i].clone()),)* cache.into_inner()) };
|
||||
}
|
||||
let node: ImaginateNode<Color, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _> = instanciate_imaginate_node!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,);
|
||||
let any = graphene_std::any::DynAnyNode::new(ValueNode::new(node));
|
||||
any.into_type_erased()
|
||||
})
|
||||
},
|
||||
NodeIOTypes::new(
|
||||
concrete!(ImageFrame<Color>),
|
||||
concrete!(ImageFrame<Color>),
|
||||
vec![
|
||||
fn_type!(WasmEditorApi),
|
||||
fn_type!(ImaginateController),
|
||||
fn_type!(f64),
|
||||
fn_type!(Option<DVec2>),
|
||||
fn_type!(u32),
|
||||
fn_type!(ImaginateSamplingMethod),
|
||||
fn_type!(f64),
|
||||
fn_type!(String),
|
||||
fn_type!(String),
|
||||
fn_type!(bool),
|
||||
fn_type!(f64),
|
||||
fn_type!(Option<Vec<u64>>),
|
||||
fn_type!(bool),
|
||||
fn_type!(f64),
|
||||
fn_type!(ImaginateMaskStartingFill),
|
||||
fn_type!(bool),
|
||||
fn_type!(bool),
|
||||
fn_type!(ImaginateCache),
|
||||
],
|
||||
),
|
||||
),
|
||||
],
|
||||
async_node!(graphene_core::memo::MemoNode<_, _>, input: (), output: Image<Color>, params: [Image<Color>]),
|
||||
async_node!(graphene_core::memo::MemoNode<_, _>, input: (), output: ImageFrame<Color>, params: [ImageFrame<Color>]),
|
||||
async_node!(graphene_core::memo::MemoNode<_, _>, input: (), output: QuantizationChannels, params: [QuantizationChannels]),
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue