Remove dead code for Imaginate

This commit is contained in:
Keavon Chambers 2025-06-26 18:33:00 -07:00
parent 1a4d7aa23c
commit 1875779b0a
32 changed files with 23 additions and 2022 deletions

View file

@ -193,9 +193,7 @@ pub enum ApplicationError {
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub enum NodeGraphUpdateMessage {
// ImaginateStatusUpdate,
}
pub enum NodeGraphUpdateMessage {}
pub trait NodeGraphUpdateSender {
fn send(&self, message: NodeGraphUpdateMessage);
@ -208,7 +206,6 @@ impl<T: NodeGraphUpdateSender> NodeGraphUpdateSender for std::sync::Mutex<T> {
}
pub trait GetEditorPreferences {
// fn hostname(&self) -> &str;
fn use_vello(&self) -> bool;
}
@ -250,10 +247,6 @@ impl NodeGraphUpdateSender for Logger {
struct DummyPreferences;
impl GetEditorPreferences for DummyPreferences {
// fn hostname(&self) -> &str {
// "dummy_endpoint"
// }
fn use_vello(&self) -> bool {
false
}

View file

@ -9,7 +9,6 @@ use crate::uuid::{NodeId, generate_uuid};
use crate::vector::style::{Fill, Stroke, StrokeAlign, ViewMode};
use crate::vector::{PointId, VectorDataTable};
use crate::{Artboard, ArtboardGroupTable, Color, GraphicElement, GraphicGroupTable};
use base64::Engine;
use bezier_rs::Subpath;
use dyn_any::DynAny;
use glam::{DAffine2, DMat2, DVec2};
@ -1148,6 +1147,8 @@ impl GraphicElementRendered for RasterDataTable<CPU> {
}
let base64_string = image.base64_string.clone().unwrap_or_else(|| {
use base64::Engine;
let output = image.to_png();
let preamble = "data:image/png;base64,";
let mut base64_string = String::with_capacity(preamble.len() + output.len() * 4);

View file

@ -248,11 +248,6 @@ tagged_value! {
ReferencePoint(graphene_core::transform::ReferencePoint),
CentroidType(graphene_core::vector::misc::CentroidType),
BooleanOperation(graphene_core::vector::misc::BooleanOperation),
// ImaginateCache(ImaginateCache),
// ImaginateSamplingMethod(ImaginateSamplingMethod),
// ImaginateMaskStartingFill(ImaginateMaskStartingFill),
// ImaginateController(ImaginateController),
}
impl TaggedValue {

View file

@ -1,279 +0,0 @@
use dyn_any::DynAny;
use graphene_core::Color;
use std::borrow::Cow;
use std::fmt::Debug;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
#[derive(Default, Debug, Clone, DynAny, specta::Type, serde::Serialize, serde::Deserialize)]
pub struct ImaginateCache(Arc<Mutex<graphene_core::raster::Image<Color>>>);
impl ImaginateCache {
pub fn into_inner(self) -> Arc<Mutex<graphene_core::raster::Image<Color>>> {
self.0
}
}
impl std::cmp::PartialEq for ImaginateCache {
fn eq(&self, other: &Self) -> bool {
Arc::ptr_eq(&self.0, &other.0)
}
}
impl core::hash::Hash for ImaginateCache {
fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
let _ = self.0.try_lock().map(|g| g.hash(state)).map_err(|_| "error".hash(state));
}
}
pub trait ImaginateTerminationHandle: Debug + Send + 'static {
fn terminate(&self);
}
#[derive(Default, Debug, specta::Type, serde::Serialize, serde::Deserialize)]
struct InternalImaginateControl {
#[serde(skip)]
status: Mutex<ImaginateStatus>,
trigger_regenerate: AtomicBool,
#[serde(skip)]
#[specta(skip)]
termination_sender: Mutex<Option<Box<dyn ImaginateTerminationHandle>>>,
}
#[derive(Debug, Default, Clone, DynAny, specta::Type, serde::Serialize, serde::Deserialize)]
pub struct ImaginateController(Arc<InternalImaginateControl>);
impl ImaginateController {
pub fn get_status(&self) -> ImaginateStatus {
self.0.status.try_lock().as_deref().cloned().unwrap_or_default()
}
pub fn set_status(&self, status: ImaginateStatus) {
if let Ok(mut lock) = self.0.status.try_lock() {
*lock = status
}
}
pub fn take_regenerate_trigger(&self) -> bool {
self.0.trigger_regenerate.swap(false, Ordering::SeqCst)
}
pub fn trigger_regenerate(&self) {
self.0.trigger_regenerate.store(true, Ordering::SeqCst)
}
pub fn request_termination(&self) {
if let Some(handle) = self.0.termination_sender.try_lock().ok().and_then(|mut lock| lock.take()) {
handle.terminate()
}
}
pub fn set_termination_handle<H: ImaginateTerminationHandle>(&self, handle: Box<H>) {
if let Ok(mut lock) = self.0.termination_sender.try_lock() {
*lock = Some(handle)
}
}
}
impl std::cmp::PartialEq for ImaginateController {
fn eq(&self, other: &Self) -> bool {
Arc::ptr_eq(&self.0, &other.0)
}
}
impl core::hash::Hash for ImaginateController {
fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
core::ptr::hash(Arc::as_ptr(&self.0), state)
}
}
#[derive(Default, Debug, Clone, PartialEq, DynAny, specta::Type, serde::Serialize, serde::Deserialize)]
pub enum ImaginateStatus {
#[default]
Ready,
ReadyDone,
Beginning,
Uploading,
Generating(f64),
Terminating,
Terminated,
Failed(String),
}
impl ImaginateStatus {
pub fn to_text(&self) -> Cow<'static, str> {
match self {
Self::Ready => Cow::Borrowed("Ready"),
Self::ReadyDone => Cow::Borrowed("Done"),
Self::Beginning => Cow::Borrowed("Beginning…"),
Self::Uploading => Cow::Borrowed("Downloading Image…"),
Self::Generating(percent) => Cow::Owned(format!("Generating {percent:.0}%")),
Self::Terminating => Cow::Owned("Terminating…".to_string()),
Self::Terminated => Cow::Owned("Terminated".to_string()),
Self::Failed(err) => Cow::Owned(format!("Failed: {err}")),
}
}
}
#[allow(clippy::derived_hash_with_manual_eq)]
impl core::hash::Hash for ImaginateStatus {
fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
core::mem::discriminant(self).hash(state);
match self {
Self::Ready | Self::ReadyDone | Self::Beginning | Self::Uploading | Self::Terminating | Self::Terminated => (),
Self::Generating(f) => f.to_bits().hash(state),
Self::Failed(err) => err.hash(state),
}
}
}
#[derive(PartialEq, Eq, Clone, Default, Debug)]
pub enum ImaginateServerStatus {
#[default]
Unknown,
Checking,
Connected,
Failed(String),
Unavailable,
}
impl ImaginateServerStatus {
pub fn to_text(&self) -> Cow<'static, str> {
match self {
Self::Unknown | Self::Checking => Cow::Borrowed("Checking..."),
Self::Connected => Cow::Borrowed("Connected"),
Self::Failed(err) => Cow::Owned(err.clone()),
Self::Unavailable => Cow::Borrowed("Unavailable"),
}
}
}
#[derive(Debug, Default, Clone, Copy, Eq, PartialEq, specta::Type, Hash, serde::Serialize, serde::Deserialize)]
pub enum ImaginateMaskPaintMode {
#[default]
Inpaint,
Outpaint,
}
#[derive(Debug, Default, Clone, Copy, Eq, PartialEq, DynAny, specta::Type, Hash, serde::Serialize, serde::Deserialize)]
pub enum ImaginateMaskStartingFill {
#[default]
Fill,
Original,
LatentNoise,
LatentNothing,
}
impl ImaginateMaskStartingFill {
pub fn list() -> [ImaginateMaskStartingFill; 4] {
[
ImaginateMaskStartingFill::Fill,
ImaginateMaskStartingFill::Original,
ImaginateMaskStartingFill::LatentNoise,
ImaginateMaskStartingFill::LatentNothing,
]
}
}
impl std::fmt::Display for ImaginateMaskStartingFill {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ImaginateMaskStartingFill::Fill => write!(f, "Smeared Surroundings"),
ImaginateMaskStartingFill::Original => write!(f, "Original Input Image"),
ImaginateMaskStartingFill::LatentNoise => write!(f, "Randomness (Latent Noise)"),
ImaginateMaskStartingFill::LatentNothing => write!(f, "Neutral (Latent Nothing)"),
}
}
}
#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, DynAny, specta::Type, Hash, serde::Serialize, serde::Deserialize)]
pub enum ImaginateSamplingMethod {
#[default]
EulerA,
Euler,
LMS,
Heun,
DPM2,
DPM2A,
DPMPlusPlus2sA,
DPMPlusPlus2m,
DPMFast,
DPMAdaptive,
LMSKarras,
DPM2Karras,
DPM2AKarras,
DPMPlusPlus2sAKarras,
DPMPlusPlus2mKarras,
DDIM,
PLMS,
}
impl ImaginateSamplingMethod {
pub fn api_value(&self) -> &str {
match self {
ImaginateSamplingMethod::EulerA => "Euler a",
ImaginateSamplingMethod::Euler => "Euler",
ImaginateSamplingMethod::LMS => "LMS",
ImaginateSamplingMethod::Heun => "Heun",
ImaginateSamplingMethod::DPM2 => "DPM2",
ImaginateSamplingMethod::DPM2A => "DPM2 a",
ImaginateSamplingMethod::DPMPlusPlus2sA => "DPM++ 2S a",
ImaginateSamplingMethod::DPMPlusPlus2m => "DPM++ 2M",
ImaginateSamplingMethod::DPMFast => "DPM fast",
ImaginateSamplingMethod::DPMAdaptive => "DPM adaptive",
ImaginateSamplingMethod::LMSKarras => "LMS Karras",
ImaginateSamplingMethod::DPM2Karras => "DPM2 Karras",
ImaginateSamplingMethod::DPM2AKarras => "DPM2 a Karras",
ImaginateSamplingMethod::DPMPlusPlus2sAKarras => "DPM++ 2S a Karras",
ImaginateSamplingMethod::DPMPlusPlus2mKarras => "DPM++ 2M Karras",
ImaginateSamplingMethod::DDIM => "DDIM",
ImaginateSamplingMethod::PLMS => "PLMS",
}
}
pub fn list() -> [ImaginateSamplingMethod; 17] {
[
ImaginateSamplingMethod::EulerA,
ImaginateSamplingMethod::Euler,
ImaginateSamplingMethod::LMS,
ImaginateSamplingMethod::Heun,
ImaginateSamplingMethod::DPM2,
ImaginateSamplingMethod::DPM2A,
ImaginateSamplingMethod::DPMPlusPlus2sA,
ImaginateSamplingMethod::DPMPlusPlus2m,
ImaginateSamplingMethod::DPMFast,
ImaginateSamplingMethod::DPMAdaptive,
ImaginateSamplingMethod::LMSKarras,
ImaginateSamplingMethod::DPM2Karras,
ImaginateSamplingMethod::DPM2AKarras,
ImaginateSamplingMethod::DPMPlusPlus2sAKarras,
ImaginateSamplingMethod::DPMPlusPlus2mKarras,
ImaginateSamplingMethod::DDIM,
ImaginateSamplingMethod::PLMS,
]
}
}
impl std::fmt::Display for ImaginateSamplingMethod {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ImaginateSamplingMethod::EulerA => write!(f, "Euler A (Recommended)"),
ImaginateSamplingMethod::Euler => write!(f, "Euler"),
ImaginateSamplingMethod::LMS => write!(f, "LMS"),
ImaginateSamplingMethod::Heun => write!(f, "Heun"),
ImaginateSamplingMethod::DPM2 => write!(f, "DPM2"),
ImaginateSamplingMethod::DPM2A => write!(f, "DPM2 A"),
ImaginateSamplingMethod::DPMPlusPlus2sA => write!(f, "DPM++ 2S a"),
ImaginateSamplingMethod::DPMPlusPlus2m => write!(f, "DPM++ 2M"),
ImaginateSamplingMethod::DPMFast => write!(f, "DPM Fast"),
ImaginateSamplingMethod::DPMAdaptive => write!(f, "DPM Adaptive"),
ImaginateSamplingMethod::LMSKarras => write!(f, "LMS Karras"),
ImaginateSamplingMethod::DPM2Karras => write!(f, "DPM2 Karras"),
ImaginateSamplingMethod::DPM2AKarras => write!(f, "DPM2 A Karras"),
ImaginateSamplingMethod::DPMPlusPlus2sAKarras => write!(f, "DPM++ 2S a Karras"),
ImaginateSamplingMethod::DPMPlusPlus2mKarras => write!(f, "DPM++ 2M Karras"),
ImaginateSamplingMethod::DDIM => write!(f, "DDIM"),
ImaginateSamplingMethod::PLMS => write!(f, "PLMS"),
}
}
}

View file

@ -317,14 +317,10 @@ pub type WasmSurfaceHandleFrame = graphene_application_io::SurfaceHandleFrame<wg
#[derive(Clone, Debug, PartialEq, Hash, specta::Type, serde::Serialize, serde::Deserialize)]
pub struct EditorPreferences {
// pub imaginate_hostname: String,
pub use_vello: bool,
}
impl graphene_application_io::GetEditorPreferences for EditorPreferences {
// fn hostname(&self) -> &str {
// &self.imaginate_hostname
// }
fn use_vello(&self) -> bool {
self.use_vello
}
@ -333,7 +329,6 @@ impl graphene_application_io::GetEditorPreferences for EditorPreferences {
impl Default for EditorPreferences {
fn default() -> Self {
Self {
// imaginate_hostname: "http://localhost:7860/".into(),
#[cfg(target_arch = "wasm32")]
use_vello: false,
#[cfg(not(target_arch = "wasm32"))]

View file

@ -7,11 +7,16 @@ authors = ["Graphite Authors <contact@graphite.rs>"]
license = "MIT OR Apache-2.0"
[features]
default = ["wasm", "imaginate"]
default = ["wasm"]
gpu = []
wgpu = ["gpu", "graph-craft/wgpu", "graphene-application-io/wgpu"]
wasm = ["wasm-bindgen", "web-sys", "graphene-application-io/wasm"]
imaginate = ["image/png", "base64", "web-sys", "wasm-bindgen-futures"]
wasm = [
"wasm-bindgen",
"wasm-bindgen-futures",
"web-sys",
"graphene-application-io/wasm",
"image/png",
]
image-compare = []
vello = ["dep:vello", "gpu", "graphene-core/vello"]
resvg = []
@ -39,9 +44,9 @@ rand_chacha = { workspace = true }
rand = { workspace = true }
bytemuck = { workspace = true }
image = { workspace = true }
base64 = { workspace = true }
# Optional workspace dependencies
base64 = { workspace = true, optional = true }
wasm-bindgen = { workspace = true, optional = true }
wasm-bindgen-futures = { workspace = true, optional = true }
tokio = { workspace = true, optional = true }
@ -58,7 +63,7 @@ web-sys = { workspace = true, optional = true, features = [
"ImageBitmapRenderingContext",
] }
# Optional dependencies
# Required dependencies
ndarray = "0.16.1"
[dev-dependencies]

View file

@ -1,526 +0,0 @@
use crate::wasm_application_io::WasmEditorApi;
use core::any::TypeId;
use core::future::Future;
use futures::TryFutureExt;
use futures::future::Either;
use glam::{DVec2, U64Vec2};
use graph_craft::imaginate_input::{ImaginateController, ImaginateMaskStartingFill, ImaginateSamplingMethod, ImaginateServerStatus, ImaginateStatus, ImaginateTerminationHandle};
use graph_craft::wasm_application_io::EditorPreferences;
use graphene_core::application_io::NodeGraphUpdateMessage;
use graphene_core::raster::{Color, Image, Luma, Pixel};
use image::{DynamicImage, ImageBuffer, ImageFormat};
use reqwest::Url;
const PROGRESS_EVERY_N_STEPS: u32 = 5;
const SDAPI_TEXT_TO_IMAGE: &str = "sdapi/v1/txt2img";
const SDAPI_IMAGE_TO_IMAGE: &str = "sdapi/v1/img2img";
const SDAPI_PROGRESS: &str = "sdapi/v1/progress?skip_current_image=true";
const SDAPI_TERMINATE: &str = "sdapi/v1/interrupt";
fn new_client() -> Result<reqwest::Client, Error> {
reqwest::ClientBuilder::new().build().map_err(Error::ClientBuild)
}
fn parse_url(url: &str) -> Result<Url, Error> {
url.try_into().map_err(|err| Error::UrlParse { text: url.into(), err })
}
fn join_url(base_url: &Url, path: &str) -> Result<Url, Error> {
base_url.join(path).map_err(|err| Error::UrlParse { text: base_url.to_string(), err })
}
fn new_get_request<U: reqwest::IntoUrl>(client: &reqwest::Client, url: U) -> Result<reqwest::Request, Error> {
client.get(url).header("Accept", "*/*").build().map_err(Error::RequestBuild)
}
pub struct ImaginatePersistentData {
pending_server_check: Option<futures::channel::oneshot::Receiver<reqwest::Result<reqwest::Response>>>,
host_name: Url,
client: Option<reqwest::Client>,
server_status: ImaginateServerStatus,
}
impl core::fmt::Debug for ImaginatePersistentData {
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
f.debug_struct(core::any::type_name::<Self>())
.field("pending_server_check", &self.pending_server_check.is_some())
.field("host_name", &self.host_name)
.field("status", &self.server_status)
.finish()
}
}
impl Default for ImaginatePersistentData {
fn default() -> Self {
let server_status = ImaginateServerStatus::default();
#[cfg(not(miri))]
let mut server_status = server_status;
#[cfg(not(miri))]
let client = new_client().map_err(|err| server_status = ImaginateServerStatus::Failed(err.to_string())).ok();
#[cfg(miri)]
let client = None;
let EditorPreferences { imaginate_hostname: host_name, .. } = Default::default();
Self {
pending_server_check: None,
host_name: parse_url(&host_name).unwrap(),
client,
server_status,
}
}
}
type ImaginateFuture = core::pin::Pin<Box<dyn Future<Output = ()> + 'static>>;
impl ImaginatePersistentData {
pub fn set_host_name(&mut self, name: &str) {
match parse_url(name) {
Ok(url) => self.host_name = url,
Err(err) => self.server_status = ImaginateServerStatus::Failed(err.to_string()),
}
}
fn initiate_server_check_maybe_fail(&mut self) -> Result<Option<ImaginateFuture>, Error> {
use futures::future::FutureExt;
let Some(client) = &self.client else {
return Ok(None);
};
if self.pending_server_check.is_some() {
return Ok(None);
}
self.server_status = ImaginateServerStatus::Checking;
let url = join_url(&self.host_name, SDAPI_PROGRESS)?;
let request = new_get_request(client, url)?;
let (send, recv) = futures::channel::oneshot::channel();
let response_future = client.execute(request).map(move |r| {
let _ = send.send(r);
});
self.pending_server_check = Some(recv);
Ok(Some(Box::pin(response_future)))
}
pub fn initiate_server_check(&mut self) -> Option<ImaginateFuture> {
match self.initiate_server_check_maybe_fail() {
Ok(f) => f,
Err(err) => {
self.server_status = ImaginateServerStatus::Failed(err.to_string());
None
}
}
}
pub fn poll_server_check(&mut self) {
if let Some(mut check) = self.pending_server_check.take() {
self.server_status = match check.try_recv().map(|r| r.map(|r| r.and_then(reqwest::Response::error_for_status))) {
Ok(Some(Ok(_response))) => ImaginateServerStatus::Connected,
Ok(Some(Err(_))) | Err(_) => ImaginateServerStatus::Unavailable,
Ok(None) => {
self.pending_server_check = Some(check);
ImaginateServerStatus::Checking
}
}
}
}
pub fn server_status(&self) -> &ImaginateServerStatus {
&self.server_status
}
pub fn is_checking(&self) -> bool {
matches!(self.server_status, ImaginateServerStatus::Checking)
}
}
#[derive(Debug)]
struct ImaginateFutureAbortHandle(futures::future::AbortHandle);
impl ImaginateTerminationHandle for ImaginateFutureAbortHandle {
fn terminate(&self) {
self.0.abort()
}
}
#[derive(Debug)]
enum Error {
UrlParse { text: String, err: <&'static str as TryInto<Url>>::Error },
ClientBuild(reqwest::Error),
RequestBuild(reqwest::Error),
Request(reqwest::Error),
ResponseFormat(reqwest::Error),
NoImage,
Base64Decode(base64::DecodeError),
ImageDecode(image::error::ImageError),
ImageEncode(image::error::ImageError),
UnsupportedPixelType(&'static str),
InconsistentImageSize,
Terminated,
TerminationFailed(reqwest::Error),
}
impl core::fmt::Display for Error {
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
match self {
Self::UrlParse { text, err } => write!(f, "invalid url '{text}' ({err})"),
Self::ClientBuild(err) => write!(f, "failed to create a reqwest client ({err})"),
Self::RequestBuild(err) => write!(f, "failed to create a reqwest request ({err})"),
Self::Request(err) => write!(f, "request failed ({err})"),
Self::ResponseFormat(err) => write!(f, "got an invalid API response ({err})"),
Self::NoImage => write!(f, "got an empty API response"),
Self::Base64Decode(err) => write!(f, "failed to decode base64 encoded image ({err})"),
Self::ImageDecode(err) => write!(f, "failed to decode png image ({err})"),
Self::ImageEncode(err) => write!(f, "failed to encode png image ({err})"),
Self::UnsupportedPixelType(ty) => write!(f, "pixel type `{ty}` not supported for imaginate images"),
Self::InconsistentImageSize => write!(f, "image width and height do not match the image byte size"),
Self::Terminated => write!(f, "imaginate request was terminated by the user"),
Self::TerminationFailed(err) => write!(f, "termination failed ({err})"),
}
}
}
impl std::error::Error for Error {}
#[derive(Default, Debug, Clone, serde::Deserialize)]
struct ImageResponse {
images: Vec<String>,
}
#[derive(Default, Debug, Clone, serde::Deserialize)]
struct ProgressResponse {
progress: f64,
}
#[derive(Debug, Clone, Copy, serde::Serialize)]
struct ImaginateTextToImageRequestOverrideSettings {
show_progress_every_n_steps: u32,
}
impl Default for ImaginateTextToImageRequestOverrideSettings {
fn default() -> Self {
Self {
show_progress_every_n_steps: PROGRESS_EVERY_N_STEPS,
}
}
}
#[derive(Debug, Clone, Copy, serde::Serialize)]
struct ImaginateImageToImageRequestOverrideSettings {
show_progress_every_n_steps: u32,
img2img_fix_steps: bool,
}
impl Default for ImaginateImageToImageRequestOverrideSettings {
fn default() -> Self {
Self {
show_progress_every_n_steps: PROGRESS_EVERY_N_STEPS,
img2img_fix_steps: true,
}
}
}
#[derive(Debug, Clone, serde::Serialize)]
struct ImaginateTextToImageRequest<'a> {
#[serde(flatten)]
common: ImaginateCommonImageRequest<'a>,
override_settings: ImaginateTextToImageRequestOverrideSettings,
}
#[derive(Debug, Clone, serde::Serialize)]
struct ImaginateMask {
mask: String,
mask_blur: String,
inpainting_fill: u32,
inpaint_full_res: bool,
inpainting_mask_invert: u32,
}
#[derive(Debug, Clone, serde::Serialize)]
struct ImaginateImageToImageRequest<'a> {
#[serde(flatten)]
common: ImaginateCommonImageRequest<'a>,
override_settings: ImaginateImageToImageRequestOverrideSettings,
init_images: Vec<String>,
denoising_strength: f64,
#[serde(flatten)]
mask: Option<ImaginateMask>,
}
#[derive(Debug, Clone, serde::Serialize)]
struct ImaginateCommonImageRequest<'a> {
prompt: String,
seed: f64,
steps: u32,
cfg_scale: f64,
width: f64,
height: f64,
restore_faces: bool,
tiling: bool,
negative_prompt: String,
sampler_index: &'a str,
}
#[cfg(all(feature = "imaginate", feature = "serde"))]
#[allow(clippy::too_many_arguments)]
pub async fn imaginate<'a, P: Pixel>(
image: Image<P>,
editor_api: impl Future<Output = &'a WasmEditorApi>,
controller: ImaginateController,
seed: impl Future<Output = f64>,
res: impl Future<Output = Option<DVec2>>,
samples: impl Future<Output = u32>,
sampling_method: impl Future<Output = ImaginateSamplingMethod>,
prompt_guidance: impl Future<Output = f64>,
prompt: impl Future<Output = String>,
negative_prompt: impl Future<Output = String>,
adapt_input_image: impl Future<Output = bool>,
image_creativity: impl Future<Output = f64>,
inpaint: impl Future<Output = bool>,
mask_blur: impl Future<Output = f64>,
mask_starting_fill: impl Future<Output = ImaginateMaskStartingFill>,
improve_faces: impl Future<Output = bool>,
tiling: impl Future<Output = bool>,
) -> Image<P> {
let WasmEditorApi {
node_graph_message_sender,
editor_preferences,
..
} = editor_api.await;
let set_progress = |progress: ImaginateStatus| {
controller.set_status(progress);
node_graph_message_sender.send(NodeGraphUpdateMessage::ImaginateStatusUpdate);
};
let host_name = editor_preferences.hostname();
imaginate_maybe_fail(
image,
host_name,
set_progress,
&controller,
seed,
res,
samples,
sampling_method,
prompt_guidance,
prompt,
negative_prompt,
adapt_input_image,
image_creativity,
inpaint,
mask_blur,
mask_starting_fill,
improve_faces,
tiling,
)
.await
.unwrap_or_else(|err| {
match err {
Error::Terminated => {
set_progress(ImaginateStatus::Terminated);
}
err => {
error!("{err}");
set_progress(ImaginateStatus::Failed(err.to_string()));
}
};
Image::default()
})
}
#[cfg(all(feature = "imaginate", feature = "serde"))]
#[allow(clippy::too_many_arguments)]
async fn imaginate_maybe_fail<P: Pixel, F: Fn(ImaginateStatus)>(
image: Image<P>,
host_name: &str,
set_progress: F,
controller: &ImaginateController,
seed: impl Future<Output = f64>,
res: impl Future<Output = Option<DVec2>>,
samples: impl Future<Output = u32>,
sampling_method: impl Future<Output = ImaginateSamplingMethod>,
prompt_guidance: impl Future<Output = f64>,
prompt: impl Future<Output = String>,
negative_prompt: impl Future<Output = String>,
adapt_input_image: impl Future<Output = bool>,
image_creativity: impl Future<Output = f64>,
_inpaint: impl Future<Output = bool>,
_mask_blur: impl Future<Output = f64>,
_mask_starting_fill: impl Future<Output = ImaginateMaskStartingFill>,
improve_faces: impl Future<Output = bool>,
tiling: impl Future<Output = bool>,
) -> Result<Image<P>, Error> {
set_progress(ImaginateStatus::Beginning);
let base_url: Url = parse_url(host_name)?;
let client = new_client()?;
let sampler_index = sampling_method.await;
let sampler_index = sampler_index.api_value();
let res = res.await.unwrap_or_else(|| {
let (width, height) = pick_safe_imaginate_resolution((image.width as _, image.height as _));
DVec2::new(width as _, height as _)
});
let common_request_data = ImaginateCommonImageRequest {
prompt: prompt.await,
seed: seed.await,
steps: samples.await,
cfg_scale: prompt_guidance.await,
width: res.x,
height: res.y,
restore_faces: improve_faces.await,
tiling: tiling.await,
negative_prompt: negative_prompt.await,
sampler_index,
};
let request_builder = if adapt_input_image.await {
let base64_data = image_to_base64(image)?;
let request_data = ImaginateImageToImageRequest {
common: common_request_data,
override_settings: Default::default(),
init_images: vec![base64_data],
denoising_strength: image_creativity.await * 0.01,
mask: None,
};
let url = join_url(&base_url, SDAPI_IMAGE_TO_IMAGE)?;
client.post(url).json(&request_data)
} else {
let request_data = ImaginateTextToImageRequest {
common: common_request_data,
override_settings: Default::default(),
};
let url = join_url(&base_url, SDAPI_TEXT_TO_IMAGE)?;
client.post(url).json(&request_data)
};
let request = request_builder.header("Accept", "*/*").build().map_err(Error::RequestBuild)?;
let (response_future, abort_handle) = futures::future::abortable(client.execute(request));
controller.set_termination_handle(Box::new(ImaginateFutureAbortHandle(abort_handle)));
let progress_url = join_url(&base_url, SDAPI_PROGRESS)?;
futures::pin_mut!(response_future);
let response = loop {
let progress_request = new_get_request(&client, progress_url.clone())?;
let progress_response_future = client.execute(progress_request).and_then(|response| response.json());
futures::pin_mut!(progress_response_future);
response_future = match futures::future::select(response_future, progress_response_future).await {
Either::Left((response, _)) => break response,
Either::Right((progress, response_future)) => {
if let Ok(ProgressResponse { progress }) = progress {
set_progress(ImaginateStatus::Generating(progress * 100.));
}
response_future
}
};
};
let response = match response {
Ok(response) => response.and_then(reqwest::Response::error_for_status).map_err(Error::Request)?,
Err(_aborted) => {
set_progress(ImaginateStatus::Terminating);
let url = join_url(&base_url, SDAPI_TERMINATE)?;
let request = client.post(url).build().map_err(Error::RequestBuild)?;
// The user probably doesn't really care if the server side was really aborted or if there was an network error.
// So we fool them that the request was terminated if the termination request in reality failed.
let _ = client.execute(request).await.and_then(reqwest::Response::error_for_status).map_err(Error::TerminationFailed)?;
return Err(Error::Terminated);
}
};
set_progress(ImaginateStatus::Uploading);
let ImageResponse { images } = response.json().await.map_err(Error::ResponseFormat)?;
let result = images.into_iter().next().ok_or(Error::NoImage).and_then(base64_to_image)?;
set_progress(ImaginateStatus::ReadyDone);
Ok(result)
}
fn image_to_base64<P: Pixel>(image: Image<P>) -> Result<String, Error> {
use base64::prelude::*;
let Image { width, height, data, .. } = image;
fn cast_with_f32<S: Pixel, D: image::Pixel<Subpixel = f32>>(data: Vec<S>, width: u32, height: u32) -> Result<DynamicImage, Error>
where
DynamicImage: From<ImageBuffer<D, Vec<f32>>>,
{
ImageBuffer::<D, Vec<f32>>::from_raw(width, height, bytemuck::cast_vec(data))
.ok_or(Error::InconsistentImageSize)
.map(Into::into)
}
let image: DynamicImage = match TypeId::of::<P>() {
id if id == TypeId::of::<Color>() => cast_with_f32::<_, image::Rgba<f32>>(data, width, height)?
// we need to do this cast, because png does not support rgba32f
.to_rgba16().into(),
id if id == TypeId::of::<Luma>() => cast_with_f32::<_, image::Luma<f32>>(data, width, height)?
// we need to do this cast, because png does not support luma32f
.to_luma16().into(),
_ => return Err(Error::UnsupportedPixelType(core::any::type_name::<P>())),
};
let mut png_data = std::io::Cursor::new(vec![]);
image.write_to(&mut png_data, ImageFormat::Png).map_err(Error::ImageEncode)?;
Ok(BASE64_STANDARD.encode(png_data.into_inner()))
}
fn base64_to_image<D: AsRef<[u8]>, P: Pixel>(base64_data: D) -> Result<Image<P>, Error> {
use base64::prelude::*;
let png_data = BASE64_STANDARD.decode(base64_data).map_err(Error::Base64Decode)?;
let dyn_image = image::load_from_memory_with_format(&png_data, image::ImageFormat::Png).map_err(Error::ImageDecode)?;
let (width, height) = (dyn_image.width(), dyn_image.height());
let result_data: Vec<P> = match TypeId::of::<P>() {
id if id == TypeId::of::<Color>() => bytemuck::cast_vec(dyn_image.into_rgba32f().into_raw()),
id if id == TypeId::of::<Luma>() => bytemuck::cast_vec(dyn_image.to_luma32f().into_raw()),
_ => return Err(Error::UnsupportedPixelType(core::any::type_name::<P>())),
};
Ok(Image {
data: result_data,
width,
height,
base64_string: None,
})
}
pub fn pick_safe_imaginate_resolution((width, height): (f64, f64)) -> (u64, u64) {
const NATIVE_MODEL_RESOLUTION: f64 = 512.;
let size = if width * height == 0. { DVec2::splat(NATIVE_MODEL_RESOLUTION) } else { DVec2::new(width, height) };
const MAX_RESOLUTION: u64 = 1000 * 1000;
// This is the maximum width/height that can be obtained
const MAX_DIMENSION: u64 = (MAX_RESOLUTION / 64) & !63;
// Round the resolution to the nearest multiple of 64
let size = (size.round().clamp(DVec2::ZERO, DVec2::splat(MAX_DIMENSION as _)).as_u64vec2() + U64Vec2::splat(32)).max(U64Vec2::splat(64)) & !U64Vec2::splat(63);
let resolution = size.x * size.y;
if resolution > MAX_RESOLUTION {
// Scale down the image, so it is smaller than MAX_RESOLUTION
let scale = (MAX_RESOLUTION as f64 / resolution as f64).sqrt();
let size = size.as_dvec2() * scale;
if size.x < 64. {
// The image is extremely wide
(64, MAX_DIMENSION)
} else if size.y < 64. {
// The image is extremely high
(MAX_DIMENSION, 64)
} else {
// Round down to a multiple of 64, so that the resolution still is smaller than MAX_RESOLUTION
(size.as_u64vec2() & !U64Vec2::splat(63)).into()
}
} else {
size.into()
}
}

View file

@ -305,103 +305,6 @@ fn image_value(_: impl Ctx, _primary: (), image: RasterDataTable<CPU>) -> Raster
image
}
// macro_rules! generate_imaginate_node {
// ($($val:ident: $t:ident: $o:ty,)*) => {
// pub struct ImaginateNode<P: Pixel, E, C, G, $($t,)*> {
// editor_api: E,
// controller: C,
// generation_id: G,
// $($val: $t,)*
// cache: std::sync::Arc<std::sync::Mutex<HashMap<u64, Image<P>>>>,
// last_generation: std::sync::atomic::AtomicU64,
// }
// impl<'e, P: Pixel, E, C, G, $($t,)*> ImaginateNode<P, E, C, G, $($t,)*>
// where $($t: for<'any_input> Node<'any_input, (), Output = DynFuture<'any_input, $o>>,)*
// E: for<'any_input> Node<'any_input, (), Output = DynFuture<'any_input, &'e WasmEditorApi>>,
// C: for<'any_input> Node<'any_input, (), Output = DynFuture<'any_input, ImaginateController>>,
// G: for<'any_input> Node<'any_input, (), Output = DynFuture<'any_input, u64>>,
// {
// #[allow(clippy::too_many_arguments)]
// pub fn new(editor_api: E, controller: C, $($val: $t,)* generation_id: G ) -> Self {
// Self { editor_api, controller, generation_id, $($val,)* cache: Default::default(), last_generation: std::sync::atomic::AtomicU64::new(u64::MAX) }
// }
// }
// impl<'i, 'e: 'i, P: Pixel + 'i + Hash + Default + Send, E: 'i, C: 'i, G: 'i, $($t: 'i,)*> Node<'i, RasterData<P>> for ImaginateNode<P, E, C, G, $($t,)*>
// where $($t: for<'any_input> Node<'any_input, (), Output = DynFuture<'any_input, $o>>,)*
// E: for<'any_input> Node<'any_input, (), Output = DynFuture<'any_input, &'e WasmEditorApi>>,
// C: for<'any_input> Node<'any_input, (), Output = DynFuture<'any_input, ImaginateController>>,
// G: for<'any_input> Node<'any_input, (), Output = DynFuture<'any_input, u64>>,
// {
// type Output = DynFuture<'i, RasterData<P>>;
// fn eval(&'i self, frame: RasterData<P>) -> Self::Output {
// let controller = self.controller.eval(());
// $(let $val = self.$val.eval(());)*
// use std::hash::Hasher;
// let mut hasher = rustc_hash::FxHasher::default();
// frame.image.hash(&mut hasher);
// let hash = hasher.finish();
// let editor_api = self.editor_api.eval(());
// let cache = self.cache.clone();
// let generation_future = self.generation_id.eval(());
// let last_generation = &self.last_generation;
// Box::pin(async move {
// let controller: ImaginateController = controller.await;
// let generation_id = generation_future.await;
// if generation_id != last_generation.swap(generation_id, std::sync::atomic::Ordering::SeqCst) {
// let image = super::imaginate::imaginate(frame.image, editor_api, controller, $($val,)*).await;
// cache.lock().unwrap().insert(hash, image.clone());
// return wrap_image_frame(image, frame.transform);
// }
// let image = cache.lock().unwrap().get(&hash).cloned().unwrap_or_default();
// return wrap_image_frame(image, frame.transform);
// })
// }
// }
// }
// }
// fn wrap_image_frame<P: Pixel>(image: Image<P>, transform: DAffine2) -> RasterData<P> {
// if !transform.decompose_scale().abs_diff_eq(DVec2::ZERO, 0.00001) {
// RasterData {
// image,
// transform,
// alpha_blending: AlphaBlending::default(),
// }
// } else {
// let resolution = DVec2::new(image.height as f64, image.width as f64);
// RasterData {
// image,
// transform: DAffine2::from_scale_angle_translation(resolution, 0., transform.translation),
// alpha_blending: AlphaBlending::default(),
// }
// }
// }
// generate_imaginate_node! {
// seed: Seed: f64,
// res: Res: Option<DVec2>,
// samples: Samples: u32,
// sampling_method: SamplingMethod: ImaginateSamplingMethod,
// prompt_guidance: PromptGuidance: f64,
// prompt: Prompt: String,
// negative_prompt: NegativePrompt: String,
// adapt_input_image: AdaptInputImage: bool,
// image_creativity: ImageCreativity: f64,
// inpaint: Inpaint: bool,
// mask_blur: MaskBlur: f64,
// mask_starting_fill: MaskStartingFill: ImaginateMaskStartingFill,
// improve_faces: ImproveFaces: bool,
// tiling: Tiling: bool,
// }
#[node_macro::node(category("Raster: Pattern"))]
#[allow(clippy::too_many_arguments)]
fn noise_pattern(