mirror of
https://github.com/GraphiteEditor/Graphite.git
synced 2025-08-08 07:18:01 +00:00
Fix Imaginate by porting its JS roundtrip code to graph-based async execution in Rust (#1250)
* Create asynchronous rust imaginate node * Make a first imaginate request via rust * Implement parsing of imaginate API result image * Stop refresh timer from affecting imaginate progress requests * Add cargo-about clarification for rustls-webpki * Delete imaginate.ts and all uses of its functions * Add imaginate img2img feature * Fix imaginate random seed button * Fix imaginate ui inferring non-custom resolutions * Fix the imaginate progress indicator * Remove ImaginatePreferences from being compiled into node graph * Regenerate imaginate only when hitting button * Add ability to terminate imaginate requests * Add imaginate server check feature * Do not compile wasm_bindgen bindings in graphite_editor for tests * Address some review suggestions - move wasm futures dependency in editor to the future-executor crate - guard wasm-bindgen in editor behind a `wasm` feature flag - dont make seed number input a slider - remove poll_server_check from process_message function beginning - guard wasm related code behind `cfg(target_arch = "wasm32")` instead of `cfg(test)` - Call the imaginate idle states "Ready" and "Done" instead of "Nothing to do" - Call the imaginate uploading state "Uploading Image" instead of "Uploading Input Image" - Remove the EvalSyncNode * Fix imaginate host name being restored between graphite instances also change the progress status texts a bit. --------- Co-authored-by: Keavon Chambers <keavon@keavon.com>
This commit is contained in:
parent
a1c70c4d90
commit
f76b850b9c
35 changed files with 1500 additions and 1326 deletions
|
@ -76,6 +76,7 @@ impl<_I, _O, S0> DynAnyRefNode<_I, _O, S0> {
|
|||
Self { node, _i: core::marker::PhantomData }
|
||||
}
|
||||
}
|
||||
|
||||
pub struct DynAnyInRefNode<I, O, Node> {
|
||||
node: Node,
|
||||
_i: PhantomData<(I, O)>,
|
||||
|
@ -115,6 +116,10 @@ where
|
|||
fn reset(&self) {
|
||||
self.node.reset();
|
||||
}
|
||||
|
||||
fn serialize(&self) -> Option<std::sync::Arc<dyn core::any::Any>> {
|
||||
self.node.serialize()
|
||||
}
|
||||
}
|
||||
|
||||
impl<N> FutureWrapperNode<N> {
|
||||
|
|
|
@ -1,5 +1,3 @@
|
|||
use std::future::Future;
|
||||
|
||||
use crate::Node;
|
||||
|
||||
pub struct GetNode;
|
||||
|
@ -17,16 +15,3 @@ pub struct PostNode<Body> {
|
|||
async fn post_node(url: String, body: String) -> reqwest::Response {
|
||||
reqwest::Client::new().post(url).body(body).send().await.unwrap()
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub struct EvalSyncNode {}
|
||||
|
||||
#[node_macro::node_fn(EvalSyncNode)]
|
||||
fn eval_sync<F: Future + 'input>(future: F) -> F::Output {
|
||||
let future = futures::future::maybe_done(future);
|
||||
futures::pin_mut!(future);
|
||||
match future.as_mut().take_output() {
|
||||
Some(value) => value,
|
||||
_ => panic!("Node construction future returned pending"),
|
||||
}
|
||||
}
|
||||
|
|
517
node-graph/gstd/src/imaginate.rs
Normal file
517
node-graph/gstd/src/imaginate.rs
Normal file
|
@ -0,0 +1,517 @@
|
|||
use crate::wasm_application_io::WasmEditorApi;
|
||||
use core::any::TypeId;
|
||||
use core::future::Future;
|
||||
use futures::{future::Either, TryFutureExt};
|
||||
use glam::DVec2;
|
||||
use graph_craft::imaginate_input::{ImaginateController, ImaginateMaskStartingFill, ImaginatePreferences, ImaginateSamplingMethod, ImaginateServerStatus, ImaginateStatus, ImaginateTerminationHandle};
|
||||
use graphene_core::application_io::NodeGraphUpdateMessage;
|
||||
use graphene_core::raster::{Color, Image, Luma, Pixel};
|
||||
use image::{DynamicImage, ImageBuffer, ImageOutputFormat};
|
||||
use reqwest::Url;
|
||||
|
||||
const PROGRESS_EVERY_N_STEPS: u32 = 5;
|
||||
const SDAPI_TEXT_TO_IMAGE: &str = "sdapi/v1/txt2img";
|
||||
const SDAPI_IMAGE_TO_IMAGE: &str = "sdapi/v1/img2img";
|
||||
const SDAPI_PROGRESS: &str = "sdapi/v1/progress?skip_current_image=true";
|
||||
const SDAPI_TERMINATE: &str = "sdapi/v1/interrupt";
|
||||
|
||||
fn new_client() -> Result<reqwest::Client, Error> {
|
||||
reqwest::ClientBuilder::new().build().map_err(Error::ClientBuild)
|
||||
}
|
||||
|
||||
fn parse_url(url: &str) -> Result<Url, Error> {
|
||||
url.try_into().map_err(|err| Error::UrlParse { text: url.into(), err })
|
||||
}
|
||||
|
||||
fn join_url(base_url: &Url, path: &str) -> Result<Url, Error> {
|
||||
base_url.join(path).map_err(|err| Error::UrlParse { text: base_url.to_string(), err })
|
||||
}
|
||||
|
||||
fn new_get_request<U: reqwest::IntoUrl>(client: &reqwest::Client, url: U) -> Result<reqwest::Request, Error> {
|
||||
client.get(url).header("Accept", "*/*").build().map_err(Error::RequestBuild)
|
||||
}
|
||||
|
||||
pub struct ImaginatePersistentData {
|
||||
pending_server_check: Option<futures::channel::oneshot::Receiver<reqwest::Result<reqwest::Response>>>,
|
||||
host_name: Url,
|
||||
client: Option<reqwest::Client>,
|
||||
server_status: ImaginateServerStatus,
|
||||
}
|
||||
|
||||
impl core::fmt::Debug for ImaginatePersistentData {
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
|
||||
f.debug_struct(core::any::type_name::<Self>())
|
||||
.field("pending_server_check", &self.pending_server_check.is_some())
|
||||
.field("host_name", &self.host_name)
|
||||
.field("status", &self.server_status)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ImaginatePersistentData {
|
||||
fn default() -> Self {
|
||||
let mut status = ImaginateServerStatus::default();
|
||||
let client = new_client().map_err(|err| status = ImaginateServerStatus::Failed(err.to_string())).ok();
|
||||
let ImaginatePreferences { host_name } = Default::default();
|
||||
Self {
|
||||
pending_server_check: None,
|
||||
host_name: parse_url(&host_name).unwrap(),
|
||||
client,
|
||||
server_status: status,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ImaginatePersistentData {
|
||||
pub fn set_host_name(&mut self, name: &str) {
|
||||
match parse_url(name) {
|
||||
Ok(url) => self.host_name = url,
|
||||
Err(err) => self.server_status = ImaginateServerStatus::Failed(err.to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
fn initiate_server_check_maybe_fail(&mut self) -> Result<Option<core::pin::Pin<Box<dyn Future<Output = ()> + 'static>>>, Error> {
|
||||
use futures::future::FutureExt;
|
||||
let Some(client) = &self.client else { return Ok(None); };
|
||||
if self.pending_server_check.is_some() {
|
||||
return Ok(None);
|
||||
}
|
||||
self.server_status = ImaginateServerStatus::Checking;
|
||||
let url = join_url(&self.host_name, SDAPI_PROGRESS)?;
|
||||
let request = new_get_request(client, url)?;
|
||||
let (send, recv) = futures::channel::oneshot::channel();
|
||||
let response_future = client.execute(request).map(move |r| {
|
||||
let _ = send.send(r);
|
||||
});
|
||||
self.pending_server_check = Some(recv);
|
||||
Ok(Some(Box::pin(response_future)))
|
||||
}
|
||||
|
||||
pub fn initiate_server_check(&mut self) -> Option<core::pin::Pin<Box<dyn Future<Output = ()> + 'static>>> {
|
||||
match self.initiate_server_check_maybe_fail() {
|
||||
Ok(f) => f,
|
||||
Err(err) => {
|
||||
self.server_status = ImaginateServerStatus::Failed(err.to_string());
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn poll_server_check(&mut self) {
|
||||
if let Some(mut check) = self.pending_server_check.take() {
|
||||
self.server_status = match check.try_recv().map(|r| r.map(|r| r.and_then(reqwest::Response::error_for_status))) {
|
||||
Ok(Some(Ok(_response))) => ImaginateServerStatus::Connected,
|
||||
Ok(Some(Err(_))) | Err(_) => ImaginateServerStatus::Unavailable,
|
||||
Ok(None) => {
|
||||
self.pending_server_check = Some(check);
|
||||
ImaginateServerStatus::Checking
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn server_status(&self) -> &ImaginateServerStatus {
|
||||
&self.server_status
|
||||
}
|
||||
|
||||
pub fn is_checking(&self) -> bool {
|
||||
matches!(self.server_status, ImaginateServerStatus::Checking)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct ImaginateFutureAbortHandle(futures::future::AbortHandle);
|
||||
|
||||
impl ImaginateTerminationHandle for ImaginateFutureAbortHandle {
|
||||
fn terminate(&self) {
|
||||
self.0.abort()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum Error {
|
||||
UrlParse { text: String, err: <&'static str as TryInto<Url>>::Error },
|
||||
ClientBuild(reqwest::Error),
|
||||
RequestBuild(reqwest::Error),
|
||||
Request(reqwest::Error),
|
||||
ResponseFormat(reqwest::Error),
|
||||
NoImage,
|
||||
Base64Decode(base64::DecodeError),
|
||||
ImageDecode(image::error::ImageError),
|
||||
ImageEncode(image::error::ImageError),
|
||||
UnsupportedPixelType(&'static str),
|
||||
InconsistentImageSize,
|
||||
Terminated,
|
||||
TerminationFailed(reqwest::Error),
|
||||
}
|
||||
|
||||
impl core::fmt::Display for Error {
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
|
||||
match self {
|
||||
Self::UrlParse { text, err } => write!(f, "invalid url '{text}' ({err})"),
|
||||
Self::ClientBuild(err) => write!(f, "failed to create a reqwest client ({err})"),
|
||||
Self::RequestBuild(err) => write!(f, "failed to create a reqwest request ({err})"),
|
||||
Self::Request(err) => write!(f, "request failed ({err})"),
|
||||
Self::ResponseFormat(err) => write!(f, "got an invalid API response ({err})"),
|
||||
Self::NoImage => write!(f, "got an empty API response"),
|
||||
Self::Base64Decode(err) => write!(f, "failed to decode base64 encoded image ({err})"),
|
||||
Self::ImageDecode(err) => write!(f, "failed to decode png image ({err})"),
|
||||
Self::ImageEncode(err) => write!(f, "failed to encode png image ({err})"),
|
||||
Self::UnsupportedPixelType(ty) => write!(f, "pixel type `{ty}` not supported for imaginate images"),
|
||||
Self::InconsistentImageSize => write!(f, "image width and height do not match the image byte size"),
|
||||
Self::Terminated => write!(f, "imaginate request was terminated by the user"),
|
||||
Self::TerminationFailed(err) => write!(f, "termination failed ({err})"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for Error {}
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Deserialize))]
|
||||
struct ImageResponse {
|
||||
images: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Deserialize))]
|
||||
struct ProgressResponse {
|
||||
progress: f64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
|
||||
struct ImaginateTextToImageRequestOverrideSettings {
|
||||
show_progress_every_n_steps: u32,
|
||||
}
|
||||
|
||||
impl Default for ImaginateTextToImageRequestOverrideSettings {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
show_progress_every_n_steps: PROGRESS_EVERY_N_STEPS,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
|
||||
struct ImaginateImageToImageRequestOverrideSettings {
|
||||
show_progress_every_n_steps: u32,
|
||||
img2img_fix_steps: bool,
|
||||
}
|
||||
|
||||
impl Default for ImaginateImageToImageRequestOverrideSettings {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
show_progress_every_n_steps: PROGRESS_EVERY_N_STEPS,
|
||||
img2img_fix_steps: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
|
||||
struct ImaginateTextToImageRequest<'a> {
|
||||
#[serde(flatten)]
|
||||
common: ImaginateCommonImageRequest<'a>,
|
||||
override_settings: ImaginateTextToImageRequestOverrideSettings,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
|
||||
struct ImaginateMask {
|
||||
mask: String,
|
||||
mask_blur: String,
|
||||
inpainting_fill: u32,
|
||||
inpaint_full_res: bool,
|
||||
inpainting_mask_invert: u32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
|
||||
struct ImaginateImageToImageRequest<'a> {
|
||||
#[serde(flatten)]
|
||||
common: ImaginateCommonImageRequest<'a>,
|
||||
override_settings: ImaginateImageToImageRequestOverrideSettings,
|
||||
|
||||
init_images: Vec<String>,
|
||||
denoising_strength: f64,
|
||||
#[serde(flatten)]
|
||||
mask: Option<ImaginateMask>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
|
||||
struct ImaginateCommonImageRequest<'a> {
|
||||
prompt: String,
|
||||
seed: f64,
|
||||
steps: u32,
|
||||
cfg_scale: f64,
|
||||
width: f64,
|
||||
height: f64,
|
||||
restore_faces: bool,
|
||||
tiling: bool,
|
||||
negative_prompt: String,
|
||||
sampler_index: &'a str,
|
||||
}
|
||||
|
||||
#[cfg(feature = "imaginate")]
|
||||
pub async fn imaginate<'a, P: Pixel>(
|
||||
image: Image<P>,
|
||||
editor_api: impl Future<Output = WasmEditorApi<'a>>,
|
||||
controller: ImaginateController,
|
||||
seed: impl Future<Output = f64>,
|
||||
res: impl Future<Output = Option<DVec2>>,
|
||||
samples: impl Future<Output = u32>,
|
||||
sampling_method: impl Future<Output = ImaginateSamplingMethod>,
|
||||
prompt_guidance: impl Future<Output = f64>,
|
||||
prompt: impl Future<Output = String>,
|
||||
negative_prompt: impl Future<Output = String>,
|
||||
adapt_input_image: impl Future<Output = bool>,
|
||||
image_creativity: impl Future<Output = f64>,
|
||||
masking_layer: impl Future<Output = Option<Vec<u64>>>,
|
||||
inpaint: impl Future<Output = bool>,
|
||||
mask_blur: impl Future<Output = f64>,
|
||||
mask_starting_fill: impl Future<Output = ImaginateMaskStartingFill>,
|
||||
improve_faces: impl Future<Output = bool>,
|
||||
tiling: impl Future<Output = bool>,
|
||||
) -> Image<P> {
|
||||
let WasmEditorApi {
|
||||
node_graph_message_sender,
|
||||
imaginate_preferences,
|
||||
..
|
||||
} = editor_api.await;
|
||||
let set_progress = |progress: ImaginateStatus| {
|
||||
controller.set_status(progress);
|
||||
node_graph_message_sender.send(NodeGraphUpdateMessage::ImaginateStatusUpdate);
|
||||
};
|
||||
let host_name = imaginate_preferences.get_host_name();
|
||||
imaginate_maybe_fail(
|
||||
image,
|
||||
host_name,
|
||||
set_progress,
|
||||
&controller,
|
||||
seed,
|
||||
res,
|
||||
samples,
|
||||
sampling_method,
|
||||
prompt_guidance,
|
||||
prompt,
|
||||
negative_prompt,
|
||||
adapt_input_image,
|
||||
image_creativity,
|
||||
masking_layer,
|
||||
inpaint,
|
||||
mask_blur,
|
||||
mask_starting_fill,
|
||||
improve_faces,
|
||||
tiling,
|
||||
)
|
||||
.await
|
||||
.unwrap_or_else(|err| {
|
||||
match err {
|
||||
Error::Terminated => {
|
||||
set_progress(ImaginateStatus::Terminated);
|
||||
}
|
||||
err => {
|
||||
error!("{err}");
|
||||
set_progress(ImaginateStatus::Failed(err.to_string()));
|
||||
}
|
||||
};
|
||||
Image::empty()
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(feature = "imaginate")]
|
||||
async fn imaginate_maybe_fail<'a, P: Pixel, F: Fn(ImaginateStatus)>(
|
||||
image: Image<P>,
|
||||
host_name: &str,
|
||||
set_progress: F,
|
||||
controller: &ImaginateController,
|
||||
seed: impl Future<Output = f64>,
|
||||
res: impl Future<Output = Option<DVec2>>,
|
||||
samples: impl Future<Output = u32>,
|
||||
sampling_method: impl Future<Output = ImaginateSamplingMethod>,
|
||||
prompt_guidance: impl Future<Output = f64>,
|
||||
prompt: impl Future<Output = String>,
|
||||
negative_prompt: impl Future<Output = String>,
|
||||
adapt_input_image: impl Future<Output = bool>,
|
||||
image_creativity: impl Future<Output = f64>,
|
||||
_masking_layer: impl Future<Output = Option<Vec<u64>>>,
|
||||
_inpaint: impl Future<Output = bool>,
|
||||
_mask_blur: impl Future<Output = f64>,
|
||||
_mask_starting_fill: impl Future<Output = ImaginateMaskStartingFill>,
|
||||
improve_faces: impl Future<Output = bool>,
|
||||
tiling: impl Future<Output = bool>,
|
||||
) -> Result<Image<P>, Error> {
|
||||
set_progress(ImaginateStatus::Beginning);
|
||||
|
||||
let base_url: Url = parse_url(host_name)?;
|
||||
|
||||
let client = new_client()?;
|
||||
|
||||
let sampler_index = sampling_method.await;
|
||||
let sampler_index = sampler_index.api_value();
|
||||
|
||||
let res = res.await.unwrap_or_else(|| {
|
||||
let (width, height) = pick_safe_imaginate_resolution((image.width as _, image.height as _));
|
||||
DVec2::new(width as _, height as _)
|
||||
});
|
||||
let common_request_data = ImaginateCommonImageRequest {
|
||||
prompt: prompt.await,
|
||||
seed: seed.await,
|
||||
steps: samples.await,
|
||||
cfg_scale: prompt_guidance.await,
|
||||
width: res.x,
|
||||
height: res.y,
|
||||
restore_faces: improve_faces.await,
|
||||
tiling: tiling.await,
|
||||
negative_prompt: negative_prompt.await,
|
||||
sampler_index,
|
||||
};
|
||||
let request_builder = if adapt_input_image.await {
|
||||
let base64_data = image_to_base64(image)?;
|
||||
let request_data = ImaginateImageToImageRequest {
|
||||
common: common_request_data,
|
||||
override_settings: Default::default(),
|
||||
|
||||
init_images: vec![base64_data],
|
||||
denoising_strength: image_creativity.await * 0.01,
|
||||
mask: None,
|
||||
};
|
||||
let url = join_url(&base_url, SDAPI_IMAGE_TO_IMAGE)?;
|
||||
client.post(url).json(&request_data)
|
||||
} else {
|
||||
let request_data = ImaginateTextToImageRequest {
|
||||
common: common_request_data,
|
||||
override_settings: Default::default(),
|
||||
};
|
||||
let url = join_url(&base_url, SDAPI_TEXT_TO_IMAGE)?;
|
||||
client.post(url).json(&request_data)
|
||||
};
|
||||
|
||||
let request = request_builder.header("Accept", "*/*").build().map_err(Error::RequestBuild)?;
|
||||
|
||||
let (response_future, abort_handle) = futures::future::abortable(client.execute(request));
|
||||
controller.set_termination_handle(Box::new(ImaginateFutureAbortHandle(abort_handle)));
|
||||
|
||||
let progress_url = join_url(&base_url, SDAPI_PROGRESS)?;
|
||||
|
||||
futures::pin_mut!(response_future);
|
||||
|
||||
let response = loop {
|
||||
let progress_request = new_get_request(&client, progress_url.clone())?;
|
||||
let progress_response_future = client.execute(progress_request).and_then(|response| response.json());
|
||||
|
||||
futures::pin_mut!(progress_response_future);
|
||||
|
||||
response_future = match futures::future::select(response_future, progress_response_future).await {
|
||||
Either::Left((response, _)) => break response,
|
||||
Either::Right((progress, response_future)) => {
|
||||
if let Ok(ProgressResponse { progress }) = progress {
|
||||
set_progress(ImaginateStatus::Generating(progress * 100.));
|
||||
}
|
||||
response_future
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
let response = match response {
|
||||
Ok(response) => response.and_then(reqwest::Response::error_for_status).map_err(Error::Request)?,
|
||||
Err(_aborted) => {
|
||||
set_progress(ImaginateStatus::Terminating);
|
||||
let url = join_url(&base_url, SDAPI_TERMINATE)?;
|
||||
let request = client.post(url).build().map_err(Error::RequestBuild)?;
|
||||
// The user probably doesn't really care if the server side was really aborted or if there was an network error.
|
||||
// So we fool them that the request was terminated if the termination request in reality failed.
|
||||
let _ = client.execute(request).await.and_then(reqwest::Response::error_for_status).map_err(Error::TerminationFailed)?;
|
||||
return Err(Error::Terminated);
|
||||
}
|
||||
};
|
||||
|
||||
set_progress(ImaginateStatus::Uploading);
|
||||
|
||||
let ImageResponse { images } = response.json().await.map_err(Error::ResponseFormat)?;
|
||||
|
||||
let result = images.into_iter().next().ok_or(Error::NoImage).and_then(base64_to_image)?;
|
||||
|
||||
set_progress(ImaginateStatus::ReadyDone);
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn image_to_base64<P: Pixel>(image: Image<P>) -> Result<String, Error> {
|
||||
use base64::prelude::*;
|
||||
|
||||
let Image { width, height, data } = image;
|
||||
|
||||
fn cast_with_f32<S: Pixel, D: image::Pixel<Subpixel = f32>>(data: Vec<S>, width: u32, height: u32) -> Result<DynamicImage, Error>
|
||||
where
|
||||
DynamicImage: From<ImageBuffer<D, Vec<f32>>>,
|
||||
{
|
||||
ImageBuffer::<D, Vec<f32>>::from_raw(width, height, bytemuck::cast_vec(data))
|
||||
.ok_or(Error::InconsistentImageSize)
|
||||
.map(Into::into)
|
||||
}
|
||||
|
||||
let image: DynamicImage = match TypeId::of::<P>() {
|
||||
id if id == TypeId::of::<Color>() => cast_with_f32::<_, image::Rgba<f32>>(data, width, height)?
|
||||
// we need to do this cast, because png does not support rgba32f
|
||||
.to_rgba16().into(),
|
||||
id if id == TypeId::of::<Luma>() => cast_with_f32::<_, image::Luma<f32>>(data, width, height)?
|
||||
// we need to do this cast, because png does not support luma32f
|
||||
.to_luma16().into(),
|
||||
_ => return Err(Error::UnsupportedPixelType(core::any::type_name::<P>())),
|
||||
};
|
||||
|
||||
let mut png_data = std::io::Cursor::new(vec![]);
|
||||
image.write_to(&mut png_data, ImageOutputFormat::Png).map_err(Error::ImageEncode)?;
|
||||
Ok(BASE64_STANDARD.encode(png_data.into_inner()))
|
||||
}
|
||||
|
||||
fn base64_to_image<D: AsRef<[u8]>, P: Pixel>(base64_data: D) -> Result<Image<P>, Error> {
|
||||
use base64::prelude::*;
|
||||
|
||||
let png_data = BASE64_STANDARD.decode(base64_data).map_err(Error::Base64Decode)?;
|
||||
let dyn_image = image::load_from_memory_with_format(&png_data, image::ImageFormat::Png).map_err(Error::ImageDecode)?;
|
||||
let (width, height) = (dyn_image.width(), dyn_image.height());
|
||||
|
||||
let result_data: Vec<P> = match TypeId::of::<P>() {
|
||||
id if id == TypeId::of::<Color>() => bytemuck::cast_vec(dyn_image.into_rgba32f().into_raw()),
|
||||
id if id == TypeId::of::<Luma>() => bytemuck::cast_vec(dyn_image.to_luma32f().into_raw()),
|
||||
_ => return Err(Error::UnsupportedPixelType(core::any::type_name::<P>())),
|
||||
};
|
||||
|
||||
Ok(Image { data: result_data, width, height })
|
||||
}
|
||||
|
||||
pub fn pick_safe_imaginate_resolution((width, height): (f64, f64)) -> (u64, u64) {
|
||||
const MAX_RESOLUTION: u64 = 1000 * 1000;
|
||||
|
||||
// this is the maximum width/height that can be obtained
|
||||
const MAX_DIMENSION: u64 = (MAX_RESOLUTION / 64) & !63;
|
||||
|
||||
// round the resolution to the nearest multiple of 64
|
||||
let [width, height] = [width, height].map(|c| (c.round().clamp(0., MAX_DIMENSION as _) as u64 + 32).max(64) & !63);
|
||||
let resolution = width * height;
|
||||
|
||||
if resolution > MAX_RESOLUTION {
|
||||
// scale down the image, so it is smaller than MAX_RESOLUTION
|
||||
let scale = (MAX_RESOLUTION as f64 / resolution as f64).sqrt();
|
||||
let [width, height] = [width, height].map(|c| c as f64 * scale);
|
||||
|
||||
if width < 64.0 {
|
||||
// the image is extremely wide
|
||||
(64, MAX_DIMENSION)
|
||||
} else if height < 64.0 {
|
||||
// the image is extremely high
|
||||
(MAX_DIMENSION, 64)
|
||||
} else {
|
||||
// round down to a multiple of 64, so that the resolution still is smaller than MAX_RESOLUTION
|
||||
let [width, height] = [width, height].map(|c| c as u64 & !63);
|
||||
(width, height)
|
||||
}
|
||||
} else {
|
||||
(width, height)
|
||||
}
|
||||
}
|
|
@ -25,3 +25,5 @@ pub mod brush;
|
|||
|
||||
#[cfg(feature = "wasm")]
|
||||
pub mod wasm_application_io;
|
||||
|
||||
pub mod imaginate;
|
||||
|
|
|
@ -1,8 +1,11 @@
|
|||
use dyn_any::{DynAny, StaticType};
|
||||
use glam::{DAffine2, DVec2};
|
||||
use graph_craft::imaginate_input::{ImaginateController, ImaginateMaskStartingFill, ImaginateSamplingMethod};
|
||||
use graph_craft::proto::DynFuture;
|
||||
use graphene_core::raster::{Alpha, BlendMode, BlendNode, Image, ImageFrame, Linear, LinearChannel, Luminance, Pixel, RGBMut, Raster, RasterMut, RedGreenBlue, Sample};
|
||||
use graphene_core::transform::Transform;
|
||||
|
||||
use crate::wasm_application_io::WasmEditorApi;
|
||||
use graphene_core::raster::bbox::{AxisAlignedBbox, Bbox};
|
||||
use graphene_core::value::CopiedNode;
|
||||
use graphene_core::{Color, Node};
|
||||
|
@ -414,19 +417,74 @@ fn empty_image<_P: Pixel>(transform: DAffine2, color: _P) -> ImageFrame<_P> {
|
|||
ImageFrame { image, transform }
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct ImaginateNode<P, E> {
|
||||
cached: E,
|
||||
_p: PhantomData<P>,
|
||||
macro_rules! generate_imaginate_node {
|
||||
($($val:ident: $t:ident: $o:ty,)*) => {
|
||||
pub struct ImaginateNode<P: Pixel, E, C, $($t,)*> {
|
||||
editor_api: E,
|
||||
controller: C,
|
||||
$($val: $t,)*
|
||||
cache: std::sync::Arc<std::sync::Mutex<Image<P>>>,
|
||||
}
|
||||
|
||||
impl<'e, P: Pixel, E, C, $($t,)*> ImaginateNode<P, E, C, $($t,)*>
|
||||
where $($t: for<'any_input> Node<'any_input, (), Output = DynFuture<'any_input, $o>>,)*
|
||||
E: for<'any_input> Node<'any_input, (), Output = DynFuture<'any_input, WasmEditorApi<'e>>>,
|
||||
C: for<'any_input> Node<'any_input, (), Output = DynFuture<'any_input, ImaginateController>>,
|
||||
{
|
||||
pub fn new(editor_api: E, controller: C, $($val: $t,)* cache: std::sync::Arc<std::sync::Mutex<Image<P>>>) -> Self {
|
||||
Self { editor_api, controller, $($val,)* cache }
|
||||
}
|
||||
}
|
||||
|
||||
impl<'i, 'e: 'i, P: Pixel + 'i, E: 'i, C: 'i, $($t: 'i,)*> Node<'i, ImageFrame<P>> for ImaginateNode<P, E, C, $($t,)*>
|
||||
where $($t: for<'any_input> Node<'any_input, (), Output = DynFuture<'any_input, $o>>,)*
|
||||
E: for<'any_input> Node<'any_input, (), Output = DynFuture<'any_input, WasmEditorApi<'e>>>,
|
||||
C: for<'any_input> Node<'any_input, (), Output = DynFuture<'any_input, ImaginateController>>,
|
||||
{
|
||||
type Output = DynFuture<'i, ImageFrame<P>>;
|
||||
|
||||
fn eval(&'i self, frame: ImageFrame<P>) -> Self::Output {
|
||||
let controller = self.controller.eval(());
|
||||
$(let $val = self.$val.eval(());)*
|
||||
Box::pin(async move {
|
||||
let controller: std::pin::Pin<Box<dyn std::future::Future<Output = ImaginateController>>> = controller;
|
||||
let controller: ImaginateController = controller.await;
|
||||
if controller.take_regenerate_trigger() {
|
||||
let editor_api = self.editor_api.eval(());
|
||||
let image = super::imaginate::imaginate(frame.image, editor_api, controller, $($val,)*).await;
|
||||
self.cache.lock().unwrap().clone_from(&image);
|
||||
return ImageFrame {
|
||||
image,
|
||||
..frame
|
||||
}
|
||||
}
|
||||
let image = self.cache.lock().unwrap().clone();
|
||||
ImageFrame {
|
||||
image,
|
||||
..frame
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[node_macro::node_fn(ImaginateNode<_P>)]
|
||||
fn imaginate<_P: Pixel>(image_frame: ImageFrame<_P>, cached: Option<std::sync::Arc<graphene_core::raster::Image<_P>>>) -> ImageFrame<_P> {
|
||||
let cached_image = cached.map(|mut x| std::sync::Arc::make_mut(&mut x).clone()).unwrap_or(image_frame.image);
|
||||
ImageFrame {
|
||||
image: cached_image,
|
||||
transform: image_frame.transform,
|
||||
}
|
||||
generate_imaginate_node! {
|
||||
seed: Seed: f64,
|
||||
res: Res: Option<DVec2>,
|
||||
samples: Samples: u32,
|
||||
sampling_method: SamplingMethod: ImaginateSamplingMethod,
|
||||
prompt_guidance: PromptGuidance: f64,
|
||||
prompt: Prompt: String,
|
||||
negative_prompt: NegativePrompt: String,
|
||||
adapt_input_image: AdaptInputImage: bool,
|
||||
image_creativity: ImageCreativity: f64,
|
||||
masking_layer: MaskingLayer: Option<Vec<u64>>,
|
||||
inpaint: Inpaint: bool,
|
||||
mask_blur: MaskBlur: f64,
|
||||
mask_starting_fill: MaskStartingFill: ImaginateMaskStartingFill,
|
||||
improve_faces: ImproveFaces: bool,
|
||||
tiling: Tiling: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue