Make Imaginate into a node (#878)

* Simplify document node input defenitions

* Remove imaginate layer

* Imaginate node properties

* Fix serde feature gate

* Add Proc Macro for Protonode implementation

* Fix incorrect type

* Add cargo.toml metadata

* Send imaginate params to frontend

* Fix image_creativity range

* Finish imaginate implementation

* Fix the imaginate draw tool

* Remove node-graph/rpco-macro

* Cargo fmt

* Fix missing workspace member

* Changes to the resolution

* Add checkbox for Imaginate auto resolution; improve Properties panel layouts

And fix bugs in panel resizing

* Implement the Rescale button

* Reorder imports

* Update Rust deps

Co-authored-by: Keavon Chambers <keavon@keavon.com>
This commit is contained in:
0HyperCube 2022-12-20 22:51:38 +00:00 committed by Keavon Chambers
parent 2f2daa25e9
commit 2732492307
61 changed files with 2249 additions and 2596 deletions

View file

@ -22,6 +22,7 @@ rand_chacha = "0.3.1"
log = "0.4"
serde = { version = "1", features = ["derive", "rc"], optional = true }
glam = { version = "0.22" }
base64 = "0.13"
vulkano = {git = "https://github.com/GraphiteEditor/vulkano", branch = "fix_rust_gpu", optional = true}
bytemuck = {version = "1.8" }

View file

@ -1,8 +1,10 @@
use dyn_any::StaticType;
pub use dyn_any::StaticType;
use dyn_any::{DynAny, Upcast};
use dyn_clone::DynClone;
use glam::DVec2;
use std::sync::Arc;
pub use glam::DVec2;
pub use std::sync::Arc;
pub use crate::imaginate_input::{ImaginateMaskStartingFill, ImaginateSamplingMethod, ImaginateStatus};
/// A type that is known, allowing serialization (serde::Deserialize is not object safe)
#[derive(Clone, Debug, PartialEq)]
@ -15,13 +17,20 @@ pub enum TaggedValue {
F64(f64),
Bool(bool),
DVec2(DVec2),
OptionalDVec2(Option<DVec2>),
Image(graphene_core::raster::Image),
RcImage(Option<Arc<graphene_core::raster::Image>>),
Color(graphene_core::raster::color::Color),
Subpath(graphene_core::vector::subpath::Subpath),
RcSubpath(Arc<graphene_core::vector::subpath::Subpath>),
ImaginateSamplingMethod(ImaginateSamplingMethod),
ImaginateMaskStartingFill(ImaginateMaskStartingFill),
ImaginateStatus(ImaginateStatus),
LayerPath(Option<Vec<u64>>),
}
impl TaggedValue {
/// Converts to a Box<dyn DynAny> - this isn't very neat but I'm not sure of a better approach
pub fn to_value(self) -> Value {
match self {
TaggedValue::None => Box::new(()),
@ -31,10 +40,16 @@ impl TaggedValue {
TaggedValue::F64(x) => Box::new(x),
TaggedValue::Bool(x) => Box::new(x),
TaggedValue::DVec2(x) => Box::new(x),
TaggedValue::OptionalDVec2(x) => Box::new(x),
TaggedValue::Image(x) => Box::new(x),
TaggedValue::RcImage(x) => Box::new(x),
TaggedValue::Color(x) => Box::new(x),
TaggedValue::Subpath(x) => Box::new(x),
TaggedValue::RcSubpath(x) => Box::new(x),
TaggedValue::ImaginateSamplingMethod(x) => Box::new(x),
TaggedValue::ImaginateMaskStartingFill(x) => Box::new(x),
TaggedValue::ImaginateStatus(x) => Box::new(x),
TaggedValue::LayerPath(x) => Box::new(x),
}
}
}

View file

@ -0,0 +1,285 @@
#[cfg(feature = "serde")]
mod base64_serde {
use serde::{Deserialize, Deserializer, Serializer};
pub fn as_base64<S>(key: &std::sync::Arc<Vec<u8>>, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(&base64::encode(key.as_slice()))
}
pub fn from_base64<'a, D>(deserializer: D) -> Result<std::sync::Arc<Vec<u8>>, D::Error>
where
D: Deserializer<'a>,
{
use serde::de::Error;
String::deserialize(deserializer)
.and_then(|string| base64::decode(string).map_err(|err| Error::custom(err.to_string())))
.map(std::sync::Arc::new)
.map_err(serde::de::Error::custom)
}
}
use dyn_any::{DynAny, StaticType};
use glam::DVec2;
use std::fmt::Debug;
#[derive(Clone, PartialEq, Debug, DynAny)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ImaginateInput {
// User-configurable layer parameters
pub seed: u64,
pub samples: u32,
pub sampling_method: ImaginateSamplingMethod,
pub use_img2img: bool,
pub denoising_strength: f64,
pub mask_layer_ref: Option<Vec<u64>>,
pub mask_paint_mode: ImaginateMaskPaintMode,
pub mask_blur_px: u32,
pub mask_fill_content: ImaginateMaskStartingFill,
pub cfg_scale: f64,
pub prompt: String,
pub negative_prompt: String,
pub restore_faces: bool,
pub tiling: bool,
pub image_data: Option<ImaginateImageData>,
pub mime: String,
/// 0 is not started, 100 is complete.
pub percent_complete: f64,
// TODO: Have the browser dispose of this blob URL when this is dropped (like when the layer is deleted)
#[cfg_attr(feature = "serde", serde(skip))]
pub blob_url: Option<String>,
#[cfg_attr(feature = "serde", serde(skip))]
pub status: ImaginateStatus,
#[cfg_attr(feature = "serde", serde(skip))]
pub dimensions: DVec2,
}
#[derive(Default, Debug, Clone, Copy, PartialEq, DynAny)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ImaginateStatus {
#[default]
Idle,
Beginning,
Uploading(f64),
Generating,
Terminating,
Terminated,
}
#[derive(Clone, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ImaginateImageData {
#[cfg_attr(feature = "serde", serde(serialize_with = "base64_serde::as_base64", deserialize_with = "base64_serde::from_base64"))]
pub image_data: std::sync::Arc<Vec<u8>>,
}
impl Debug for ImaginateImageData {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("[image data...]")
}
}
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ImaginateBaseImage {
pub mime: String,
#[serde(rename = "imageData")]
pub image_data: Vec<u8>,
pub size: DVec2,
}
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ImaginateMaskImage {
pub svg: String,
pub size: DVec2,
}
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Default, Clone, Copy, Eq, PartialEq)]
pub enum ImaginateMaskPaintMode {
#[default]
Inpaint,
Outpaint,
}
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Default, Clone, Copy, Eq, PartialEq, DynAny)]
pub enum ImaginateMaskStartingFill {
#[default]
Fill,
Original,
LatentNoise,
LatentNothing,
}
impl ImaginateMaskStartingFill {
pub fn list() -> [ImaginateMaskStartingFill; 4] {
[
ImaginateMaskStartingFill::Fill,
ImaginateMaskStartingFill::Original,
ImaginateMaskStartingFill::LatentNoise,
ImaginateMaskStartingFill::LatentNothing,
]
}
}
impl std::fmt::Display for ImaginateMaskStartingFill {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ImaginateMaskStartingFill::Fill => write!(f, "Smeared Surroundings"),
ImaginateMaskStartingFill::Original => write!(f, "Original Base Image"),
ImaginateMaskStartingFill::LatentNoise => write!(f, "Randomness (Latent Noise)"),
ImaginateMaskStartingFill::LatentNothing => write!(f, "Neutral (Latent Nothing)"),
}
}
}
#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, DynAny)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ImaginateSamplingMethod {
#[default]
EulerA,
Euler,
LMS,
Heun,
DPM2,
DPM2A,
DPMPlusPlus2sA,
DPMPlusPlus2m,
DPMFast,
DPMAdaptive,
LMSKarras,
DPM2Karras,
DPM2AKarras,
DPMPlusPlus2sAKarras,
DPMPlusPlus2mKarras,
DDIM,
PLMS,
}
impl ImaginateSamplingMethod {
pub fn api_value(&self) -> &str {
match self {
ImaginateSamplingMethod::EulerA => "Euler a",
ImaginateSamplingMethod::Euler => "Euler",
ImaginateSamplingMethod::LMS => "LMS",
ImaginateSamplingMethod::Heun => "Heun",
ImaginateSamplingMethod::DPM2 => "DPM2",
ImaginateSamplingMethod::DPM2A => "DPM2 a",
ImaginateSamplingMethod::DPMPlusPlus2sA => "DPM++ 2S a",
ImaginateSamplingMethod::DPMPlusPlus2m => "DPM++ 2M",
ImaginateSamplingMethod::DPMFast => "DPM fast",
ImaginateSamplingMethod::DPMAdaptive => "DPM adaptive",
ImaginateSamplingMethod::LMSKarras => "LMS Karras",
ImaginateSamplingMethod::DPM2Karras => "DPM2 Karras",
ImaginateSamplingMethod::DPM2AKarras => "DPM2 a Karras",
ImaginateSamplingMethod::DPMPlusPlus2sAKarras => "DPM++ 2S a Karras",
ImaginateSamplingMethod::DPMPlusPlus2mKarras => "DPM++ 2M Karras",
ImaginateSamplingMethod::DDIM => "DDIM",
ImaginateSamplingMethod::PLMS => "PLMS",
}
}
pub fn list() -> [ImaginateSamplingMethod; 17] {
[
ImaginateSamplingMethod::EulerA,
ImaginateSamplingMethod::Euler,
ImaginateSamplingMethod::LMS,
ImaginateSamplingMethod::Heun,
ImaginateSamplingMethod::DPM2,
ImaginateSamplingMethod::DPM2A,
ImaginateSamplingMethod::DPMPlusPlus2sA,
ImaginateSamplingMethod::DPMPlusPlus2m,
ImaginateSamplingMethod::DPMFast,
ImaginateSamplingMethod::DPMAdaptive,
ImaginateSamplingMethod::LMSKarras,
ImaginateSamplingMethod::DPM2Karras,
ImaginateSamplingMethod::DPM2AKarras,
ImaginateSamplingMethod::DPMPlusPlus2sAKarras,
ImaginateSamplingMethod::DPMPlusPlus2mKarras,
ImaginateSamplingMethod::DDIM,
ImaginateSamplingMethod::PLMS,
]
}
}
impl std::fmt::Display for ImaginateSamplingMethod {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ImaginateSamplingMethod::EulerA => write!(f, "Euler A (Recommended)"),
ImaginateSamplingMethod::Euler => write!(f, "Euler"),
ImaginateSamplingMethod::LMS => write!(f, "LMS"),
ImaginateSamplingMethod::Heun => write!(f, "Heun"),
ImaginateSamplingMethod::DPM2 => write!(f, "DPM2"),
ImaginateSamplingMethod::DPM2A => write!(f, "DPM2 A"),
ImaginateSamplingMethod::DPMPlusPlus2sA => write!(f, "DPM++ 2S a"),
ImaginateSamplingMethod::DPMPlusPlus2m => write!(f, "DPM++ 2M"),
ImaginateSamplingMethod::DPMFast => write!(f, "DPM Fast"),
ImaginateSamplingMethod::DPMAdaptive => write!(f, "DPM Adaptive"),
ImaginateSamplingMethod::LMSKarras => write!(f, "LMS Karras"),
ImaginateSamplingMethod::DPM2Karras => write!(f, "DPM2 Karras"),
ImaginateSamplingMethod::DPM2AKarras => write!(f, "DPM2 A Karras"),
ImaginateSamplingMethod::DPMPlusPlus2sAKarras => write!(f, "DPM++ 2S a Karras"),
ImaginateSamplingMethod::DPMPlusPlus2mKarras => write!(f, "DPM++ 2M Karras"),
ImaginateSamplingMethod::DDIM => write!(f, "DDIM"),
ImaginateSamplingMethod::PLMS => write!(f, "PLMS"),
}
}
}
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ImaginateGenerationParameters {
pub seed: u64,
pub samples: u32,
/// Use `ImaginateSamplingMethod::api_value()` to generate this string
#[cfg_attr(feature = "serde", serde(rename = "samplingMethod"))]
pub sampling_method: String,
#[cfg_attr(feature = "serde", serde(rename = "denoisingStrength"))]
pub image_creativity: Option<f64>,
#[cfg_attr(feature = "serde", serde(rename = "cfgScale"))]
pub text_guidance: f64,
#[cfg_attr(feature = "serde", serde(rename = "prompt"))]
pub text_prompt: String,
#[cfg_attr(feature = "serde", serde(rename = "negativePrompt"))]
pub negative_prompt: String,
pub resolution: (u32, u32),
#[cfg_attr(feature = "serde", serde(rename = "restoreFaces"))]
pub restore_faces: bool,
pub tiling: bool,
}
impl Default for ImaginateInput {
fn default() -> Self {
Self {
seed: 0,
samples: 30,
sampling_method: Default::default(),
use_img2img: false,
denoising_strength: 0.66,
mask_paint_mode: ImaginateMaskPaintMode::default(),
mask_layer_ref: None,
mask_blur_px: 4,
mask_fill_content: ImaginateMaskStartingFill::default(),
cfg_scale: 10.,
prompt: "".into(),
negative_prompt: "".into(),
restore_faces: false,
tiling: false,
image_data: None,
mime: "image/png".into(),
blob_url: None,
percent_complete: 0.,
status: Default::default(),
dimensions: Default::default(),
}
}
}

View file

@ -5,6 +5,7 @@ pub mod document;
pub mod proto;
pub mod executor;
pub mod imaginate_input;
#[cfg(feature = "gpu")]
pub mod gpu;

View file

@ -259,7 +259,6 @@ impl ProtoNetwork {
if temp_marks.contains(&node_id) {
panic!("Cycle detected");
}
info!("Visiting {node_id}");
if let Some(dependencies) = inwards_edges.get(&node_id) {
temp_marks.insert(node_id);
@ -273,7 +272,6 @@ impl ProtoNetwork {
assert!(self.nodes.iter().any(|(id, _)| *id == self.output), "Output id {} does not exist", self.output);
visit(self.output, &mut HashSet::new(), &mut sorted, &inwards_edges);
info!("Sorted order {sorted:?}");
sorted
}
@ -307,7 +305,6 @@ impl ProtoNetwork {
let order = self.topological_sort();
// Map of node ids to indexes (which become the node ids as they are inserted into the borrow stack)
let lookup: HashMap<_, _> = order.iter().enumerate().map(|(pos, id)| (*id, pos as NodeId)).collect();
info!("Order {order:?}");
self.nodes = order
.iter()
.enumerate()
@ -324,7 +321,7 @@ impl ProtoNetwork {
self.nodes.iter_mut().for_each(|(_, node)| {
node.map_ids(|id| *lookup.get(&id).expect("node not found in lookup table"));
});
self.inputs = self.inputs.iter().map(|id| *lookup.get(id).unwrap()).collect();
self.inputs = self.inputs.iter().filter_map(|id| lookup.get(id).copied()).collect();
self.output = *lookup.get(&self.output).unwrap();
}
}

View file

@ -9,9 +9,8 @@ license = "MIT OR Apache-2.0"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[features]
derive = ["graph-proc-macros"]
memoization = ["once_cell"]
default = ["derive", "memoization"]
default = ["memoization"]
gpu = ["graph-craft/gpu", "graphene-core/gpu"]
@ -19,7 +18,6 @@ gpu = ["graph-craft/gpu", "graphene-core/gpu"]
graphene-core = {path = "../gcore", features = ["async", "std" ], default-features = false}
borrow_stack = {path = "../borrow_stack"}
dyn-any = {path = "../../libraries/dyn-any", features = ["derive"]}
graph-proc-macros = {path = "../proc-macro", optional = true}
graph-craft = {path = "../graph-craft"}
bytemuck = {version = "1.8" }
tempfile = "3"
@ -37,6 +35,7 @@ kurbo = { git = "https://github.com/linebender/kurbo.git", features = [
"serde",
] }
glam = { version = "0.22", features = ["serde"] }
node-macro = { path="../node-macro" }
[dependencies.serde]
version = "1.0"

View file

@ -142,6 +142,10 @@ pub fn export_image_node<'n>() -> impl Node<(Image, &'n str), Output = Result<()
})
}
#[derive(Debug, Clone, Copy)]
pub struct GrayscaleNode;
#[node_macro::node_fn(GrayscaleNode)]
fn grayscale_image(mut image: Image) -> Image {
for pixel in &mut image.data {
let avg = (pixel.r() + pixel.g() + pixel.b()) / 3.;
@ -151,21 +155,9 @@ fn grayscale_image(mut image: Image) -> Image {
}
#[derive(Debug, Clone, Copy)]
pub struct GrayscaleNode;
impl Node<Image> for GrayscaleNode {
type Output = Image;
fn eval(self, image: Image) -> Image {
grayscale_image(image)
}
}
impl Node<Image> for &GrayscaleNode {
type Output = Image;
fn eval(self, image: Image) -> Image {
grayscale_image(image)
}
}
pub struct InvertRGBNode;
#[node_macro::node_fn(InvertRGBNode)]
fn invert_image(mut image: Image) -> Image {
for pixel in &mut image.data {
*pixel = Color::from_rgbaf32_unchecked(1. - pixel.r(), 1. - pixel.g(), 1. - pixel.b(), pixel.a());
@ -174,22 +166,15 @@ fn invert_image(mut image: Image) -> Image {
}
#[derive(Debug, Clone, Copy)]
pub struct InvertRGBNode;
impl Node<Image> for InvertRGBNode {
type Output = Image;
fn eval(self, image: Image) -> Image {
invert_image(image)
}
}
impl Node<Image> for &InvertRGBNode {
type Output = Image;
fn eval(self, image: Image) -> Image {
invert_image(image)
}
pub struct HueSaturationNode<Hue, Sat, Lit> {
hue_shift: Hue,
saturation_shift: Sat,
lightness_shift: Lit,
}
fn shift_image_hsl(mut image: Image, hue_shift: f32, saturation_shift: f32, lightness_shift: f32) -> Image {
#[node_macro::node_fn(HueSaturationNode)]
fn shift_image_hsl(mut image: Image, hue_shift: f64, saturation_shift: f64, lightness_shift: f64) -> Image {
let (hue_shift, saturation_shift, lightness_shift) = (hue_shift as f32, saturation_shift as f32, lightness_shift as f32);
for pixel in &mut image.data {
let [hue, saturation, lightness, alpha] = pixel.to_hsla();
*pixel = Color::from_hsla(
@ -203,108 +188,18 @@ fn shift_image_hsl(mut image: Image, hue_shift: f32, saturation_shift: f32, ligh
}
#[derive(Debug, Clone, Copy)]
pub struct HueSaturationNode<Hue, Sat, Lit>
where
Hue: Node<(), Output = f64>,
Sat: Node<(), Output = f64>,
Lit: Node<(), Output = f64>,
{
hue: Hue,
saturation: Sat,
lightness: Lit,
}
impl<Hue, Sat, Lit> Node<Image> for HueSaturationNode<Hue, Sat, Lit>
where
Hue: Node<(), Output = f64>,
Sat: Node<(), Output = f64>,
Lit: Node<(), Output = f64>,
{
type Output = Image;
fn eval(self, image: Image) -> Image {
shift_image_hsl(image, self.hue.eval(()) as f32, self.saturation.eval(()) as f32, self.lightness.eval(()) as f32)
}
}
impl<Hue, Sat, Lit> Node<Image> for &HueSaturationNode<Hue, Sat, Lit>
where
Hue: Node<(), Output = f64> + Copy,
Sat: Node<(), Output = f64> + Copy,
Lit: Node<(), Output = f64> + Copy,
{
type Output = Image;
fn eval(self, image: Image) -> Image {
shift_image_hsl(image, self.hue.eval(()) as f32, self.saturation.eval(()) as f32, self.lightness.eval(()) as f32)
}
}
impl<Hue, Sat, Lit> HueSaturationNode<Hue, Sat, Lit>
where
Hue: Node<(), Output = f64>,
Sat: Node<(), Output = f64>,
Lit: Node<(), Output = f64>,
{
pub fn new(hue: Hue, saturation: Sat, lightness: Lit) -> Self {
Self { hue, saturation, lightness }
}
}
// Copy pasta from https://stackoverflow.com/questions/2976274/adjust-bitmap-image-brightness-contrast-using-c
fn adjust_image_brightness_and_contrast(mut image: Image, brightness_shift: f32, contrast: f32) -> Image {
let factor = (259. * (contrast + 255.)) / (255. * (259. - contrast));
let channel = |channel: f32| ((factor * (channel * 255. + brightness_shift - 128.) + 128.) / 255.).clamp(0., 1.);
for pixel in &mut image.data {
*pixel = Color::from_rgbaf32_unchecked(channel(pixel.r()), channel(pixel.g()), channel(pixel.b()), pixel.a())
}
image
}
#[derive(Debug, Clone, Copy)]
pub struct BrightnessContrastNode<Brightness, Contrast>
where
Brightness: Node<(), Output = f64>,
Contrast: Node<(), Output = f64>,
{
pub struct BrightnessContrastNode<Brightness, Contrast> {
brightness: Brightness,
contrast: Contrast,
}
impl<Brightness, Contrast> Node<Image> for BrightnessContrastNode<Brightness, Contrast>
where
Brightness: Node<(), Output = f64>,
Contrast: Node<(), Output = f64>,
{
type Output = Image;
fn eval(self, image: Image) -> Image {
adjust_image_brightness_and_contrast(image, self.brightness.eval(()) as f32, self.contrast.eval(()) as f32)
}
}
// From https://stackoverflow.com/questions/2976274/adjust-bitmap-image-brightness-contrast-using-c
#[node_macro::node_fn(BrightnessContrastNode)]
fn adjust_image_brightness_and_contrast(mut image: Image, brightness: f64, contrast: f64) -> Image {
let (brightness, contrast) = (brightness as f32, contrast as f32);
let factor = (259. * (contrast + 255.)) / (255. * (259. - contrast));
let channel = |channel: f32| ((factor * (channel * 255. + brightness - 128.) + 128.) / 255.).clamp(0., 1.);
impl<Brightness, Contrast> Node<Image> for &BrightnessContrastNode<Brightness, Contrast>
where
Brightness: Node<(), Output = f64> + Copy,
Contrast: Node<(), Output = f64> + Copy,
{
type Output = Image;
fn eval(self, image: Image) -> Image {
adjust_image_brightness_and_contrast(image, self.brightness.eval(()) as f32, self.contrast.eval(()) as f32)
}
}
impl<Brightness, Contrast> BrightnessContrastNode<Brightness, Contrast>
where
Brightness: Node<(), Output = f64>,
Contrast: Node<(), Output = f64>,
{
pub fn new(brightness: Brightness, contrast: Contrast) -> Self {
Self { brightness, contrast }
}
}
// https://www.dfstudios.co.uk/articles/programming/image-programming-algorithms/image-processing-algorithms-part-6-gamma-correction/
fn image_gamma(mut image: Image, gamma: f32) -> Image {
let inverse_gamma = 1. / gamma;
let channel = |channel: f32| channel.powf(inverse_gamma);
for pixel in &mut image.data {
*pixel = Color::from_rgbaf32_unchecked(channel(pixel.r()), channel(pixel.g()), channel(pixel.b()), pixel.a())
}
@ -312,36 +207,44 @@ fn image_gamma(mut image: Image, gamma: f32) -> Image {
}
#[derive(Debug, Clone, Copy)]
pub struct GammaNode<N: Node<(), Output = f64>>(N);
impl<N: Node<(), Output = f64>> Node<Image> for GammaNode<N> {
type Output = Image;
fn eval(self, image: Image) -> Image {
image_gamma(image, self.0.eval(()) as f32)
}
}
impl<N: Node<(), Output = f64> + Copy> Node<Image> for &GammaNode<N> {
type Output = Image;
fn eval(self, image: Image) -> Image {
image_gamma(image, self.0.eval(()) as f32)
}
pub struct GammaNode<G> {
gamma: G,
}
impl<N: Node<(), Output = f64> + Copy> GammaNode<N> {
pub fn new(node: N) -> Self {
Self(node)
// https://www.dfstudios.co.uk/articles/programming/image-programming-algorithms/image-processing-algorithms-part-6-gamma-correction/
#[node_macro::node_fn(GammaNode)]
fn image_gamma(mut image: Image, gamma: f64) -> Image {
let inverse_gamma = 1. / gamma;
let channel = |channel: f32| channel.powf(inverse_gamma as f32);
for pixel in &mut image.data {
*pixel = Color::from_rgbaf32_unchecked(channel(pixel.r()), channel(pixel.g()), channel(pixel.b()), pixel.a())
}
image
}
fn image_opacity(mut image: Image, opacity_multiplier: f32) -> Image {
#[derive(Debug, Clone, Copy)]
pub struct OpacityNode<O> {
opacity_multiplier: O,
}
#[node_macro::node_fn(OpacityNode)]
fn image_opacity(mut image: Image, opacity_multiplier: f64) -> Image {
let opacity_multiplier = opacity_multiplier as f32;
for pixel in &mut image.data {
*pixel = Color::from_rgbaf32_unchecked(pixel.r(), pixel.g(), pixel.b(), pixel.a() * opacity_multiplier)
}
image
}
#[derive(Debug, Clone, Copy)]
pub struct PosterizeNode<P> {
posterize_value: P,
}
// Based on http://www.axiomx.com/posterize.htm
fn posterize(mut image: Image, posterize_value: f32) -> Image {
#[node_macro::node_fn(PosterizeNode)]
fn posterize(mut image: Image, posterize_value: f64) -> Image {
let posterize_value = posterize_value as f32;
let number_of_areas = posterize_value.recip();
let size_of_areas = (posterize_value - 1.).recip();
let channel = |channel: f32| (channel / number_of_areas).floor() * size_of_areas;
@ -351,9 +254,15 @@ fn posterize(mut image: Image, posterize_value: f32) -> Image {
image
}
#[derive(Debug, Clone, Copy)]
pub struct ExposureNode<E> {
exposure: E,
}
// Based on https://stackoverflow.com/questions/12166117/what-is-the-math-behind-exposure-adjustment-on-photoshop
fn exposure(mut image: Image, exposure: f32) -> Image {
let multiplier = 2f32.powf(exposure);
#[node_macro::node_fn(ExposureNode)]
fn exposure(mut image: Image, exposure: f64) -> Image {
let multiplier = 2f32.powf(exposure as f32);
let channel = |channel: f32| channel * multiplier;
for pixel in &mut image.data {
*pixel = Color::from_rgbaf32_unchecked(channel(pixel.r()), channel(pixel.g()), channel(pixel.b()), pixel.a())
@ -362,69 +271,15 @@ fn exposure(mut image: Image, exposure: f32) -> Image {
}
#[derive(Debug, Clone, Copy)]
pub struct PosterizeNode<N: Node<(), Output = f64>>(N);
impl<N: Node<(), Output = f64>> Node<Image> for PosterizeNode<N> {
type Output = Image;
fn eval(self, image: Image) -> Image {
posterize(image, self.0.eval(()) as f32)
}
}
impl<N: Node<(), Output = f64> + Copy> Node<Image> for &PosterizeNode<N> {
type Output = Image;
fn eval(self, image: Image) -> Image {
posterize(image, self.0.eval(()) as f32)
}
pub struct ImaginateNode<E> {
cached: E,
}
impl<N: Node<(), Output = f64> + Copy> PosterizeNode<N> {
pub fn new(node: N) -> Self {
Self(node)
}
}
#[derive(Debug, Clone, Copy)]
pub struct OpacityNode<N: Node<(), Output = f64>>(N);
impl<N: Node<(), Output = f64>> Node<Image> for OpacityNode<N> {
type Output = Image;
fn eval(self, image: Image) -> Image {
image_opacity(image, self.0.eval(()) as f32)
}
}
impl<N: Node<(), Output = f64> + Copy> Node<Image> for &OpacityNode<N> {
type Output = Image;
fn eval(self, image: Image) -> Image {
image_opacity(image, self.0.eval(()) as f32)
}
}
impl<N: Node<(), Output = f64> + Copy> OpacityNode<N> {
pub fn new(node: N) -> Self {
Self(node)
}
}
#[derive(Debug, Clone, Copy)]
pub struct ExposureNode<N: Node<(), Output = f64>>(N);
impl<N: Node<(), Output = f64>> Node<Image> for ExposureNode<N> {
type Output = Image;
fn eval(self, image: Image) -> Image {
exposure(image, self.0.eval(()) as f32)
}
}
impl<N: Node<(), Output = f64> + Copy> Node<Image> for &ExposureNode<N> {
type Output = Image;
fn eval(self, image: Image) -> Image {
exposure(image, self.0.eval(()) as f32)
}
}
impl<N: Node<(), Output = f64> + Copy> ExposureNode<N> {
pub fn new(node: N) -> Self {
Self(node)
}
// Based on https://stackoverflow.com/questions/12166117/what-is-the-math-behind-exposure-adjustment-on-photoshop
#[node_macro::node_fn(ImaginateNode)]
fn imaginate(image: Image, cached: Option<std::sync::Arc<graphene_core::raster::Image>>) -> Image {
info!("Imaginating image with {} pixels", image.data.len());
cached.map(|mut x| std::sync::Arc::make_mut(&mut x).clone()).unwrap_or(image)
}
#[cfg(test)]

View file

@ -190,7 +190,6 @@ static NODE_REGISTRY: &[(NodeIdentifier, NodeConstructor)] = &[
})
}),
(NodeIdentifier::new("graphene_core::raster::BrightenColorNode", &[concrete!("&TypeErasedNode")]), |proto_node, stack| {
info!("proto node {:?}", proto_node);
stack.push_fn(|nodes| {
let ConstructionArgs::Nodes(construction_nodes) = proto_node.construction_args else { unreachable!("Brighten Color Node constructed with out brightness input node") };
let value_node = nodes.get(construction_nodes[0] as usize).unwrap();
@ -272,7 +271,6 @@ static NODE_REGISTRY: &[(NodeIdentifier, NodeConstructor)] = &[
(NodeIdentifier::new("graphene_std::raster::MapImageNode", &[]), |proto_node, stack| {
if let ConstructionArgs::Nodes(operation_node_id) = proto_node.construction_args {
stack.push_fn(move |nodes| {
info!("Map image Depending upon id {:?}", operation_node_id);
let operation_node = nodes.get(operation_node_id[0] as usize).unwrap();
let operation_node: DowncastBothNode<_, Color, Color> = DowncastBothNode::new(operation_node);
let map_node = DynAnyNode::new(graphene_std::raster::MapImageNode::new(operation_node));
@ -403,6 +401,21 @@ static NODE_REGISTRY: &[(NodeIdentifier, NodeConstructor)] = &[
}
})
}),
(NodeIdentifier::new("graphene_std::raster::ImaginateNode", &[concrete!("&TypeErasedNode")]), |proto_node, stack| {
stack.push_fn(move |nodes| {
let ConstructionArgs::Nodes(construction_nodes) = proto_node.construction_args else { unreachable!("ImaginateNode constructed without inputs") };
let value: DowncastBothNode<_, (), Option<std::sync::Arc<graphene_core::raster::Image>>> = DowncastBothNode::new(nodes.get(construction_nodes[15] as usize).unwrap());
let node = DynAnyNode::new(graphene_std::raster::ImaginateNode::new(value));
if let ProtoNodeInput::Node(node_id) = proto_node.input {
let pre_node = nodes.get(node_id as usize).unwrap();
(pre_node).then(node).into_type_erased()
} else {
node.into_type_erased()
}
})
}),
(NodeIdentifier::new("graphene_std::raster::ImageNode", &[concrete!("&str")]), |_proto_node, stack| {
stack.push_fn(|_nodes| {
let image = FnNode::new(|s: &str| graphene_std::raster::image_node::<&str>().eval(s).unwrap());

View file

@ -0,0 +1,20 @@
[package]
name = "node-macro"
publish = false
version = "0.0.0"
rust-version = "1.65.0"
authors = ["Graphite Authors <contact@graphite.rs>"]
edition = "2021"
readme = "../../README.md"
homepage = "https://graphite.rs"
repository = "https://github.com/GraphiteEditor/Graphite"
license = "Apache-2.0"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[lib]
proc-macro = true
[dependencies]
syn = { version = "1.0", features = ["full"] }
quote = "1.0"

View file

@ -0,0 +1,83 @@
use proc_macro::TokenStream;
use quote::{format_ident, ToTokens};
use syn::{parse_macro_input, FnArg, Ident, ItemFn, Pat, PatIdent, ReturnType};
#[proc_macro_attribute]
pub fn node_fn(attr: TokenStream, item: TokenStream) -> TokenStream {
let node_name = parse_macro_input!(attr as Ident);
let function = parse_macro_input!(item as ItemFn);
let function_name = &function.sig.ident;
let mut function_inputs = function.sig.inputs.iter().filter_map(|arg| if let FnArg::Typed(typed_arg) = arg { Some(typed_arg) } else { None });
// Extract primary input as first argument
let primary_input = function_inputs.next().expect("Primary input required - set to `()` if not needed.");
let Pat::Ident(PatIdent{ident: primary_input_ident,..} ) =&*primary_input.pat else{
panic!("Expected ident as primary input.");
};
let primary_input_ty = &primary_input.ty;
// Extract secondary inputs as all other arguments
let secondary_inputs = function_inputs.collect::<Vec<_>>();
let secondary_idents = secondary_inputs
.iter()
.map(|input| {
let Pat::Ident(PatIdent { ident: primary_input_ident,.. }) = &*input.pat else { panic!("Expected ident for secondary input."); };
primary_input_ident
})
.collect::<Vec<_>>();
// Extract the output type of the entire node - `()` by default
let output = if let ReturnType::Type(_, ty) = &function.sig.output {
ty.to_token_stream()
} else {
quote::quote!(())
};
// Generics are simply `S0` through to `Sn-1` where n is the number of secondary inputs
let generics = (0..secondary_inputs.len()).map(|x| format_ident!("S{x}")).collect::<Vec<_>>();
// Bindings for all of the above generics to a node with an input of `()` and an output of the type in the function
let where_clause = secondary_inputs
.iter()
.zip(&generics)
.map(|(ty, name)| {
let ty = &ty.ty;
quote::quote!(#name: Node<(), Output = #ty>)
})
.collect::<Vec<_>>();
quote::quote! {
#function
impl <#(#generics),*> Node<#primary_input_ty> for #node_name<#(#generics),*>
where
#(#where_clause),* {
type Output = #output;
fn eval(self, #primary_input_ident: #primary_input_ty) -> #output{
#function_name(#primary_input_ident #(, self.#secondary_idents.eval(()))*)
}
}
impl <#(#generics),*> Node<#primary_input_ty> for &#node_name<#(#generics),*>
where
#(#where_clause + Copy),* {
type Output = #output;
fn eval(self, #primary_input_ident: #primary_input_ty) -> #output{
#function_name(#primary_input_ident #(, self.#secondary_idents.eval(()))*)
}
}
impl <#(#generics),*> #node_name<#(#generics),*>
where
#(#where_clause + Copy),* {
pub fn new(#(#secondary_idents: #generics),*) -> Self{
Self{
#(#secondary_idents),*
}
}
}
}
.into()
}

View file

@ -1,18 +0,0 @@
[package]
name = "graph-proc-macros"
version = "0.1.0"
authors = ["Graphite Authors <contact@graphite.rs>"]
edition = "2021"
publish = false
license = "MIT OR Apache-2.0"
[lib]
path = "src/lib.rs"
proc-macro = true
[dependencies]
proc-macro2 = "1.0"
proc_macro_roids = "0.7"
syn = { version = "1.0", features = ["full"] }
quote = "1.0"
graphene-core = {path = "../gcore"}

View file

@ -1,98 +0,0 @@
use proc_macro::TokenStream;
use proc_macro_roids::*;
use quote::{quote, ToTokens};
use syn::punctuated::Punctuated;
use syn::{parse_macro_input, FnArg, ItemFn, Pat, Type};
fn extract_type(a: FnArg) -> Type {
match a {
FnArg::Typed(p) => *p.ty, // notice `ty` instead of `pat`
_ => panic!("Not supported on types with `self`!"),
}
}
fn extract_arg_types(fn_args: Punctuated<FnArg, syn::token::Comma>) -> Vec<Type> {
fn_args.into_iter().map(extract_type).collect::<Vec<_>>()
}
fn extract_arg_idents(fn_args: Punctuated<FnArg, syn::token::Comma>) -> Vec<Pat> {
fn_args.into_iter().map(extract_arg_pat).collect::<Vec<_>>()
}
fn extract_arg_pat(a: FnArg) -> Pat {
match a {
FnArg::Typed(p) => *p.pat,
_ => panic!("Not supported on types with `self`!"),
}
}
#[proc_macro_attribute] // 2
pub fn to_node(_attr: TokenStream, item: TokenStream) -> TokenStream {
let string = item.to_string();
let item2 = item;
let parsed = parse_macro_input!(item2 as ItemFn); // 3
//item.extend(generate_to_string(parsed, string)); // 4
//item
generate_to_string(parsed, string)
}
fn generate_to_string(parsed: ItemFn, string: String) -> TokenStream {
let whole_function = parsed.clone();
//let fn_body = parsed.block; // function body
let sig = parsed.sig; // function signature
//let vis = parsed.vis; // visibility, pub or not
let generics = sig.generics;
let fn_args = sig.inputs; // comma separated args
let fn_return_type = sig.output; // return type
let fn_name = sig.ident; // function name/identifier
let idents = extract_arg_idents(fn_args.clone());
let types = extract_arg_types(fn_args);
let types = types.iter().map(|t| t.to_token_stream()).collect::<Vec<_>>();
let idents = idents.iter().map(|t| t.to_token_stream()).collect::<Vec<_>>();
let _const_idents = idents
.iter()
.map(|t| {
let name = t.to_string().to_uppercase();
quote! {#name}
})
.collect::<Vec<_>>();
let node_fn_name = fn_name.append("_node");
let struct_name = fn_name.append("_input");
let return_type_string = fn_return_type.to_token_stream().to_string().replace("->", "");
let arg_type_string = types.iter().map(|t| t.to_string()).collect::<Vec<_>>().join(", ");
let error = format!("called {} with the wrong type", fn_name);
let x = quote! {
//#whole_function
mod #fn_name {
#[derive(Copy, Clone)]
type F32Node<'n> = &'n (dyn Node<'n, (), Output = &'n (dyn Any + 'static)> + 'n);
struct #struct_name {
#(#idents: #types,)*
}
impl Node for #struct_name {
}
}
fn #node_fn_name #generics() -> Node<'static> {
Node { func: Box::new(move |x| {
let args = x.downcast::<(#(#types,)*)>().expect(#error);
let (#(#idents,)*) = *args;
#whole_function
Box::new(#fn_name(#(#idents,)*))
}),
code: #string.to_string(),
return_type: #return_type_string.trim().to_string(),
args: format!("({})",#arg_type_string.trim()),
position: (0., 0.),
}
}
};
//panic!("{}\n{:?}", x.to_string(), x);
x.into()
}