mirror of
https://github.com/GraphiteEditor/Graphite.git
synced 2025-07-16 04:05:00 +00:00
Make Imaginate into a node (#878)
* Simplify document node input defenitions * Remove imaginate layer * Imaginate node properties * Fix serde feature gate * Add Proc Macro for Protonode implementation * Fix incorrect type * Add cargo.toml metadata * Send imaginate params to frontend * Fix image_creativity range * Finish imaginate implementation * Fix the imaginate draw tool * Remove node-graph/rpco-macro * Cargo fmt * Fix missing workspace member * Changes to the resolution * Add checkbox for Imaginate auto resolution; improve Properties panel layouts And fix bugs in panel resizing * Implement the Rescale button * Reorder imports * Update Rust deps Co-authored-by: Keavon Chambers <keavon@keavon.com>
This commit is contained in:
parent
2f2daa25e9
commit
2732492307
61 changed files with 2249 additions and 2596 deletions
|
@ -22,6 +22,7 @@ rand_chacha = "0.3.1"
|
|||
log = "0.4"
|
||||
serde = { version = "1", features = ["derive", "rc"], optional = true }
|
||||
glam = { version = "0.22" }
|
||||
base64 = "0.13"
|
||||
|
||||
vulkano = {git = "https://github.com/GraphiteEditor/vulkano", branch = "fix_rust_gpu", optional = true}
|
||||
bytemuck = {version = "1.8" }
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
use dyn_any::StaticType;
|
||||
pub use dyn_any::StaticType;
|
||||
use dyn_any::{DynAny, Upcast};
|
||||
use dyn_clone::DynClone;
|
||||
use glam::DVec2;
|
||||
use std::sync::Arc;
|
||||
pub use glam::DVec2;
|
||||
pub use std::sync::Arc;
|
||||
|
||||
pub use crate::imaginate_input::{ImaginateMaskStartingFill, ImaginateSamplingMethod, ImaginateStatus};
|
||||
|
||||
/// A type that is known, allowing serialization (serde::Deserialize is not object safe)
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
|
@ -15,13 +17,20 @@ pub enum TaggedValue {
|
|||
F64(f64),
|
||||
Bool(bool),
|
||||
DVec2(DVec2),
|
||||
OptionalDVec2(Option<DVec2>),
|
||||
Image(graphene_core::raster::Image),
|
||||
RcImage(Option<Arc<graphene_core::raster::Image>>),
|
||||
Color(graphene_core::raster::color::Color),
|
||||
Subpath(graphene_core::vector::subpath::Subpath),
|
||||
RcSubpath(Arc<graphene_core::vector::subpath::Subpath>),
|
||||
ImaginateSamplingMethod(ImaginateSamplingMethod),
|
||||
ImaginateMaskStartingFill(ImaginateMaskStartingFill),
|
||||
ImaginateStatus(ImaginateStatus),
|
||||
LayerPath(Option<Vec<u64>>),
|
||||
}
|
||||
|
||||
impl TaggedValue {
|
||||
/// Converts to a Box<dyn DynAny> - this isn't very neat but I'm not sure of a better approach
|
||||
pub fn to_value(self) -> Value {
|
||||
match self {
|
||||
TaggedValue::None => Box::new(()),
|
||||
|
@ -31,10 +40,16 @@ impl TaggedValue {
|
|||
TaggedValue::F64(x) => Box::new(x),
|
||||
TaggedValue::Bool(x) => Box::new(x),
|
||||
TaggedValue::DVec2(x) => Box::new(x),
|
||||
TaggedValue::OptionalDVec2(x) => Box::new(x),
|
||||
TaggedValue::Image(x) => Box::new(x),
|
||||
TaggedValue::RcImage(x) => Box::new(x),
|
||||
TaggedValue::Color(x) => Box::new(x),
|
||||
TaggedValue::Subpath(x) => Box::new(x),
|
||||
TaggedValue::RcSubpath(x) => Box::new(x),
|
||||
TaggedValue::ImaginateSamplingMethod(x) => Box::new(x),
|
||||
TaggedValue::ImaginateMaskStartingFill(x) => Box::new(x),
|
||||
TaggedValue::ImaginateStatus(x) => Box::new(x),
|
||||
TaggedValue::LayerPath(x) => Box::new(x),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
285
node-graph/graph-craft/src/imaginate_input.rs
Normal file
285
node-graph/graph-craft/src/imaginate_input.rs
Normal file
|
@ -0,0 +1,285 @@
|
|||
#[cfg(feature = "serde")]
|
||||
mod base64_serde {
|
||||
use serde::{Deserialize, Deserializer, Serializer};
|
||||
|
||||
pub fn as_base64<S>(key: &std::sync::Arc<Vec<u8>>, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
serializer.serialize_str(&base64::encode(key.as_slice()))
|
||||
}
|
||||
|
||||
pub fn from_base64<'a, D>(deserializer: D) -> Result<std::sync::Arc<Vec<u8>>, D::Error>
|
||||
where
|
||||
D: Deserializer<'a>,
|
||||
{
|
||||
use serde::de::Error;
|
||||
|
||||
String::deserialize(deserializer)
|
||||
.and_then(|string| base64::decode(string).map_err(|err| Error::custom(err.to_string())))
|
||||
.map(std::sync::Arc::new)
|
||||
.map_err(serde::de::Error::custom)
|
||||
}
|
||||
}
|
||||
|
||||
use dyn_any::{DynAny, StaticType};
|
||||
use glam::DVec2;
|
||||
use std::fmt::Debug;
|
||||
|
||||
#[derive(Clone, PartialEq, Debug, DynAny)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub struct ImaginateInput {
|
||||
// User-configurable layer parameters
|
||||
pub seed: u64,
|
||||
pub samples: u32,
|
||||
pub sampling_method: ImaginateSamplingMethod,
|
||||
pub use_img2img: bool,
|
||||
pub denoising_strength: f64,
|
||||
pub mask_layer_ref: Option<Vec<u64>>,
|
||||
pub mask_paint_mode: ImaginateMaskPaintMode,
|
||||
pub mask_blur_px: u32,
|
||||
pub mask_fill_content: ImaginateMaskStartingFill,
|
||||
pub cfg_scale: f64,
|
||||
pub prompt: String,
|
||||
pub negative_prompt: String,
|
||||
pub restore_faces: bool,
|
||||
pub tiling: bool,
|
||||
|
||||
pub image_data: Option<ImaginateImageData>,
|
||||
pub mime: String,
|
||||
/// 0 is not started, 100 is complete.
|
||||
pub percent_complete: f64,
|
||||
|
||||
// TODO: Have the browser dispose of this blob URL when this is dropped (like when the layer is deleted)
|
||||
#[cfg_attr(feature = "serde", serde(skip))]
|
||||
pub blob_url: Option<String>,
|
||||
#[cfg_attr(feature = "serde", serde(skip))]
|
||||
pub status: ImaginateStatus,
|
||||
#[cfg_attr(feature = "serde", serde(skip))]
|
||||
pub dimensions: DVec2,
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Clone, Copy, PartialEq, DynAny)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub enum ImaginateStatus {
|
||||
#[default]
|
||||
Idle,
|
||||
Beginning,
|
||||
Uploading(f64),
|
||||
Generating,
|
||||
Terminating,
|
||||
Terminated,
|
||||
}
|
||||
|
||||
#[derive(Clone, Eq, PartialEq)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub struct ImaginateImageData {
|
||||
#[cfg_attr(feature = "serde", serde(serialize_with = "base64_serde::as_base64", deserialize_with = "base64_serde::from_base64"))]
|
||||
pub image_data: std::sync::Arc<Vec<u8>>,
|
||||
}
|
||||
|
||||
impl Debug for ImaginateImageData {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.write_str("[image data...]")
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub struct ImaginateBaseImage {
|
||||
pub mime: String,
|
||||
#[serde(rename = "imageData")]
|
||||
pub image_data: Vec<u8>,
|
||||
pub size: DVec2,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub struct ImaginateMaskImage {
|
||||
pub svg: String,
|
||||
pub size: DVec2,
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
|
||||
#[derive(Debug, Default, Clone, Copy, Eq, PartialEq)]
|
||||
pub enum ImaginateMaskPaintMode {
|
||||
#[default]
|
||||
Inpaint,
|
||||
Outpaint,
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
|
||||
#[derive(Debug, Default, Clone, Copy, Eq, PartialEq, DynAny)]
|
||||
pub enum ImaginateMaskStartingFill {
|
||||
#[default]
|
||||
Fill,
|
||||
Original,
|
||||
LatentNoise,
|
||||
LatentNothing,
|
||||
}
|
||||
|
||||
impl ImaginateMaskStartingFill {
|
||||
pub fn list() -> [ImaginateMaskStartingFill; 4] {
|
||||
[
|
||||
ImaginateMaskStartingFill::Fill,
|
||||
ImaginateMaskStartingFill::Original,
|
||||
ImaginateMaskStartingFill::LatentNoise,
|
||||
ImaginateMaskStartingFill::LatentNothing,
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ImaginateMaskStartingFill {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
ImaginateMaskStartingFill::Fill => write!(f, "Smeared Surroundings"),
|
||||
ImaginateMaskStartingFill::Original => write!(f, "Original Base Image"),
|
||||
ImaginateMaskStartingFill::LatentNoise => write!(f, "Randomness (Latent Noise)"),
|
||||
ImaginateMaskStartingFill::LatentNothing => write!(f, "Neutral (Latent Nothing)"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, DynAny)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub enum ImaginateSamplingMethod {
|
||||
#[default]
|
||||
EulerA,
|
||||
Euler,
|
||||
LMS,
|
||||
Heun,
|
||||
DPM2,
|
||||
DPM2A,
|
||||
DPMPlusPlus2sA,
|
||||
DPMPlusPlus2m,
|
||||
DPMFast,
|
||||
DPMAdaptive,
|
||||
LMSKarras,
|
||||
DPM2Karras,
|
||||
DPM2AKarras,
|
||||
DPMPlusPlus2sAKarras,
|
||||
DPMPlusPlus2mKarras,
|
||||
DDIM,
|
||||
PLMS,
|
||||
}
|
||||
|
||||
impl ImaginateSamplingMethod {
|
||||
pub fn api_value(&self) -> &str {
|
||||
match self {
|
||||
ImaginateSamplingMethod::EulerA => "Euler a",
|
||||
ImaginateSamplingMethod::Euler => "Euler",
|
||||
ImaginateSamplingMethod::LMS => "LMS",
|
||||
ImaginateSamplingMethod::Heun => "Heun",
|
||||
ImaginateSamplingMethod::DPM2 => "DPM2",
|
||||
ImaginateSamplingMethod::DPM2A => "DPM2 a",
|
||||
ImaginateSamplingMethod::DPMPlusPlus2sA => "DPM++ 2S a",
|
||||
ImaginateSamplingMethod::DPMPlusPlus2m => "DPM++ 2M",
|
||||
ImaginateSamplingMethod::DPMFast => "DPM fast",
|
||||
ImaginateSamplingMethod::DPMAdaptive => "DPM adaptive",
|
||||
ImaginateSamplingMethod::LMSKarras => "LMS Karras",
|
||||
ImaginateSamplingMethod::DPM2Karras => "DPM2 Karras",
|
||||
ImaginateSamplingMethod::DPM2AKarras => "DPM2 a Karras",
|
||||
ImaginateSamplingMethod::DPMPlusPlus2sAKarras => "DPM++ 2S a Karras",
|
||||
ImaginateSamplingMethod::DPMPlusPlus2mKarras => "DPM++ 2M Karras",
|
||||
ImaginateSamplingMethod::DDIM => "DDIM",
|
||||
ImaginateSamplingMethod::PLMS => "PLMS",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn list() -> [ImaginateSamplingMethod; 17] {
|
||||
[
|
||||
ImaginateSamplingMethod::EulerA,
|
||||
ImaginateSamplingMethod::Euler,
|
||||
ImaginateSamplingMethod::LMS,
|
||||
ImaginateSamplingMethod::Heun,
|
||||
ImaginateSamplingMethod::DPM2,
|
||||
ImaginateSamplingMethod::DPM2A,
|
||||
ImaginateSamplingMethod::DPMPlusPlus2sA,
|
||||
ImaginateSamplingMethod::DPMPlusPlus2m,
|
||||
ImaginateSamplingMethod::DPMFast,
|
||||
ImaginateSamplingMethod::DPMAdaptive,
|
||||
ImaginateSamplingMethod::LMSKarras,
|
||||
ImaginateSamplingMethod::DPM2Karras,
|
||||
ImaginateSamplingMethod::DPM2AKarras,
|
||||
ImaginateSamplingMethod::DPMPlusPlus2sAKarras,
|
||||
ImaginateSamplingMethod::DPMPlusPlus2mKarras,
|
||||
ImaginateSamplingMethod::DDIM,
|
||||
ImaginateSamplingMethod::PLMS,
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ImaginateSamplingMethod {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
ImaginateSamplingMethod::EulerA => write!(f, "Euler A (Recommended)"),
|
||||
ImaginateSamplingMethod::Euler => write!(f, "Euler"),
|
||||
ImaginateSamplingMethod::LMS => write!(f, "LMS"),
|
||||
ImaginateSamplingMethod::Heun => write!(f, "Heun"),
|
||||
ImaginateSamplingMethod::DPM2 => write!(f, "DPM2"),
|
||||
ImaginateSamplingMethod::DPM2A => write!(f, "DPM2 A"),
|
||||
ImaginateSamplingMethod::DPMPlusPlus2sA => write!(f, "DPM++ 2S a"),
|
||||
ImaginateSamplingMethod::DPMPlusPlus2m => write!(f, "DPM++ 2M"),
|
||||
ImaginateSamplingMethod::DPMFast => write!(f, "DPM Fast"),
|
||||
ImaginateSamplingMethod::DPMAdaptive => write!(f, "DPM Adaptive"),
|
||||
ImaginateSamplingMethod::LMSKarras => write!(f, "LMS Karras"),
|
||||
ImaginateSamplingMethod::DPM2Karras => write!(f, "DPM2 Karras"),
|
||||
ImaginateSamplingMethod::DPM2AKarras => write!(f, "DPM2 A Karras"),
|
||||
ImaginateSamplingMethod::DPMPlusPlus2sAKarras => write!(f, "DPM++ 2S a Karras"),
|
||||
ImaginateSamplingMethod::DPMPlusPlus2mKarras => write!(f, "DPM++ 2M Karras"),
|
||||
ImaginateSamplingMethod::DDIM => write!(f, "DDIM"),
|
||||
ImaginateSamplingMethod::PLMS => write!(f, "PLMS"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub struct ImaginateGenerationParameters {
|
||||
pub seed: u64,
|
||||
pub samples: u32,
|
||||
/// Use `ImaginateSamplingMethod::api_value()` to generate this string
|
||||
#[cfg_attr(feature = "serde", serde(rename = "samplingMethod"))]
|
||||
pub sampling_method: String,
|
||||
#[cfg_attr(feature = "serde", serde(rename = "denoisingStrength"))]
|
||||
pub image_creativity: Option<f64>,
|
||||
#[cfg_attr(feature = "serde", serde(rename = "cfgScale"))]
|
||||
pub text_guidance: f64,
|
||||
#[cfg_attr(feature = "serde", serde(rename = "prompt"))]
|
||||
pub text_prompt: String,
|
||||
#[cfg_attr(feature = "serde", serde(rename = "negativePrompt"))]
|
||||
pub negative_prompt: String,
|
||||
pub resolution: (u32, u32),
|
||||
#[cfg_attr(feature = "serde", serde(rename = "restoreFaces"))]
|
||||
pub restore_faces: bool,
|
||||
pub tiling: bool,
|
||||
}
|
||||
|
||||
impl Default for ImaginateInput {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
seed: 0,
|
||||
samples: 30,
|
||||
sampling_method: Default::default(),
|
||||
use_img2img: false,
|
||||
denoising_strength: 0.66,
|
||||
mask_paint_mode: ImaginateMaskPaintMode::default(),
|
||||
mask_layer_ref: None,
|
||||
mask_blur_px: 4,
|
||||
mask_fill_content: ImaginateMaskStartingFill::default(),
|
||||
cfg_scale: 10.,
|
||||
prompt: "".into(),
|
||||
negative_prompt: "".into(),
|
||||
restore_faces: false,
|
||||
tiling: false,
|
||||
|
||||
image_data: None,
|
||||
mime: "image/png".into(),
|
||||
|
||||
blob_url: None,
|
||||
percent_complete: 0.,
|
||||
status: Default::default(),
|
||||
dimensions: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
|
@ -5,6 +5,7 @@ pub mod document;
|
|||
pub mod proto;
|
||||
|
||||
pub mod executor;
|
||||
pub mod imaginate_input;
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
pub mod gpu;
|
||||
|
|
|
@ -259,7 +259,6 @@ impl ProtoNetwork {
|
|||
if temp_marks.contains(&node_id) {
|
||||
panic!("Cycle detected");
|
||||
}
|
||||
info!("Visiting {node_id}");
|
||||
|
||||
if let Some(dependencies) = inwards_edges.get(&node_id) {
|
||||
temp_marks.insert(node_id);
|
||||
|
@ -273,7 +272,6 @@ impl ProtoNetwork {
|
|||
assert!(self.nodes.iter().any(|(id, _)| *id == self.output), "Output id {} does not exist", self.output);
|
||||
visit(self.output, &mut HashSet::new(), &mut sorted, &inwards_edges);
|
||||
|
||||
info!("Sorted order {sorted:?}");
|
||||
sorted
|
||||
}
|
||||
|
||||
|
@ -307,7 +305,6 @@ impl ProtoNetwork {
|
|||
let order = self.topological_sort();
|
||||
// Map of node ids to indexes (which become the node ids as they are inserted into the borrow stack)
|
||||
let lookup: HashMap<_, _> = order.iter().enumerate().map(|(pos, id)| (*id, pos as NodeId)).collect();
|
||||
info!("Order {order:?}");
|
||||
self.nodes = order
|
||||
.iter()
|
||||
.enumerate()
|
||||
|
@ -324,7 +321,7 @@ impl ProtoNetwork {
|
|||
self.nodes.iter_mut().for_each(|(_, node)| {
|
||||
node.map_ids(|id| *lookup.get(&id).expect("node not found in lookup table"));
|
||||
});
|
||||
self.inputs = self.inputs.iter().map(|id| *lookup.get(id).unwrap()).collect();
|
||||
self.inputs = self.inputs.iter().filter_map(|id| lookup.get(id).copied()).collect();
|
||||
self.output = *lookup.get(&self.output).unwrap();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -9,9 +9,8 @@ license = "MIT OR Apache-2.0"
|
|||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[features]
|
||||
derive = ["graph-proc-macros"]
|
||||
memoization = ["once_cell"]
|
||||
default = ["derive", "memoization"]
|
||||
default = ["memoization"]
|
||||
gpu = ["graph-craft/gpu", "graphene-core/gpu"]
|
||||
|
||||
|
||||
|
@ -19,7 +18,6 @@ gpu = ["graph-craft/gpu", "graphene-core/gpu"]
|
|||
graphene-core = {path = "../gcore", features = ["async", "std" ], default-features = false}
|
||||
borrow_stack = {path = "../borrow_stack"}
|
||||
dyn-any = {path = "../../libraries/dyn-any", features = ["derive"]}
|
||||
graph-proc-macros = {path = "../proc-macro", optional = true}
|
||||
graph-craft = {path = "../graph-craft"}
|
||||
bytemuck = {version = "1.8" }
|
||||
tempfile = "3"
|
||||
|
@ -37,6 +35,7 @@ kurbo = { git = "https://github.com/linebender/kurbo.git", features = [
|
|||
"serde",
|
||||
] }
|
||||
glam = { version = "0.22", features = ["serde"] }
|
||||
node-macro = { path="../node-macro" }
|
||||
|
||||
[dependencies.serde]
|
||||
version = "1.0"
|
||||
|
|
|
@ -142,6 +142,10 @@ pub fn export_image_node<'n>() -> impl Node<(Image, &'n str), Output = Result<()
|
|||
})
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct GrayscaleNode;
|
||||
|
||||
#[node_macro::node_fn(GrayscaleNode)]
|
||||
fn grayscale_image(mut image: Image) -> Image {
|
||||
for pixel in &mut image.data {
|
||||
let avg = (pixel.r() + pixel.g() + pixel.b()) / 3.;
|
||||
|
@ -151,21 +155,9 @@ fn grayscale_image(mut image: Image) -> Image {
|
|||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct GrayscaleNode;
|
||||
|
||||
impl Node<Image> for GrayscaleNode {
|
||||
type Output = Image;
|
||||
fn eval(self, image: Image) -> Image {
|
||||
grayscale_image(image)
|
||||
}
|
||||
}
|
||||
impl Node<Image> for &GrayscaleNode {
|
||||
type Output = Image;
|
||||
fn eval(self, image: Image) -> Image {
|
||||
grayscale_image(image)
|
||||
}
|
||||
}
|
||||
pub struct InvertRGBNode;
|
||||
|
||||
#[node_macro::node_fn(InvertRGBNode)]
|
||||
fn invert_image(mut image: Image) -> Image {
|
||||
for pixel in &mut image.data {
|
||||
*pixel = Color::from_rgbaf32_unchecked(1. - pixel.r(), 1. - pixel.g(), 1. - pixel.b(), pixel.a());
|
||||
|
@ -174,22 +166,15 @@ fn invert_image(mut image: Image) -> Image {
|
|||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct InvertRGBNode;
|
||||
|
||||
impl Node<Image> for InvertRGBNode {
|
||||
type Output = Image;
|
||||
fn eval(self, image: Image) -> Image {
|
||||
invert_image(image)
|
||||
}
|
||||
}
|
||||
impl Node<Image> for &InvertRGBNode {
|
||||
type Output = Image;
|
||||
fn eval(self, image: Image) -> Image {
|
||||
invert_image(image)
|
||||
}
|
||||
pub struct HueSaturationNode<Hue, Sat, Lit> {
|
||||
hue_shift: Hue,
|
||||
saturation_shift: Sat,
|
||||
lightness_shift: Lit,
|
||||
}
|
||||
|
||||
fn shift_image_hsl(mut image: Image, hue_shift: f32, saturation_shift: f32, lightness_shift: f32) -> Image {
|
||||
#[node_macro::node_fn(HueSaturationNode)]
|
||||
fn shift_image_hsl(mut image: Image, hue_shift: f64, saturation_shift: f64, lightness_shift: f64) -> Image {
|
||||
let (hue_shift, saturation_shift, lightness_shift) = (hue_shift as f32, saturation_shift as f32, lightness_shift as f32);
|
||||
for pixel in &mut image.data {
|
||||
let [hue, saturation, lightness, alpha] = pixel.to_hsla();
|
||||
*pixel = Color::from_hsla(
|
||||
|
@ -203,108 +188,18 @@ fn shift_image_hsl(mut image: Image, hue_shift: f32, saturation_shift: f32, ligh
|
|||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct HueSaturationNode<Hue, Sat, Lit>
|
||||
where
|
||||
Hue: Node<(), Output = f64>,
|
||||
Sat: Node<(), Output = f64>,
|
||||
Lit: Node<(), Output = f64>,
|
||||
{
|
||||
hue: Hue,
|
||||
saturation: Sat,
|
||||
lightness: Lit,
|
||||
}
|
||||
|
||||
impl<Hue, Sat, Lit> Node<Image> for HueSaturationNode<Hue, Sat, Lit>
|
||||
where
|
||||
Hue: Node<(), Output = f64>,
|
||||
Sat: Node<(), Output = f64>,
|
||||
Lit: Node<(), Output = f64>,
|
||||
{
|
||||
type Output = Image;
|
||||
fn eval(self, image: Image) -> Image {
|
||||
shift_image_hsl(image, self.hue.eval(()) as f32, self.saturation.eval(()) as f32, self.lightness.eval(()) as f32)
|
||||
}
|
||||
}
|
||||
impl<Hue, Sat, Lit> Node<Image> for &HueSaturationNode<Hue, Sat, Lit>
|
||||
where
|
||||
Hue: Node<(), Output = f64> + Copy,
|
||||
Sat: Node<(), Output = f64> + Copy,
|
||||
Lit: Node<(), Output = f64> + Copy,
|
||||
{
|
||||
type Output = Image;
|
||||
fn eval(self, image: Image) -> Image {
|
||||
shift_image_hsl(image, self.hue.eval(()) as f32, self.saturation.eval(()) as f32, self.lightness.eval(()) as f32)
|
||||
}
|
||||
}
|
||||
|
||||
impl<Hue, Sat, Lit> HueSaturationNode<Hue, Sat, Lit>
|
||||
where
|
||||
Hue: Node<(), Output = f64>,
|
||||
Sat: Node<(), Output = f64>,
|
||||
Lit: Node<(), Output = f64>,
|
||||
{
|
||||
pub fn new(hue: Hue, saturation: Sat, lightness: Lit) -> Self {
|
||||
Self { hue, saturation, lightness }
|
||||
}
|
||||
}
|
||||
|
||||
// Copy pasta from https://stackoverflow.com/questions/2976274/adjust-bitmap-image-brightness-contrast-using-c
|
||||
fn adjust_image_brightness_and_contrast(mut image: Image, brightness_shift: f32, contrast: f32) -> Image {
|
||||
let factor = (259. * (contrast + 255.)) / (255. * (259. - contrast));
|
||||
let channel = |channel: f32| ((factor * (channel * 255. + brightness_shift - 128.) + 128.) / 255.).clamp(0., 1.);
|
||||
|
||||
for pixel in &mut image.data {
|
||||
*pixel = Color::from_rgbaf32_unchecked(channel(pixel.r()), channel(pixel.g()), channel(pixel.b()), pixel.a())
|
||||
}
|
||||
image
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct BrightnessContrastNode<Brightness, Contrast>
|
||||
where
|
||||
Brightness: Node<(), Output = f64>,
|
||||
Contrast: Node<(), Output = f64>,
|
||||
{
|
||||
pub struct BrightnessContrastNode<Brightness, Contrast> {
|
||||
brightness: Brightness,
|
||||
contrast: Contrast,
|
||||
}
|
||||
|
||||
impl<Brightness, Contrast> Node<Image> for BrightnessContrastNode<Brightness, Contrast>
|
||||
where
|
||||
Brightness: Node<(), Output = f64>,
|
||||
Contrast: Node<(), Output = f64>,
|
||||
{
|
||||
type Output = Image;
|
||||
fn eval(self, image: Image) -> Image {
|
||||
adjust_image_brightness_and_contrast(image, self.brightness.eval(()) as f32, self.contrast.eval(()) as f32)
|
||||
}
|
||||
}
|
||||
// From https://stackoverflow.com/questions/2976274/adjust-bitmap-image-brightness-contrast-using-c
|
||||
#[node_macro::node_fn(BrightnessContrastNode)]
|
||||
fn adjust_image_brightness_and_contrast(mut image: Image, brightness: f64, contrast: f64) -> Image {
|
||||
let (brightness, contrast) = (brightness as f32, contrast as f32);
|
||||
let factor = (259. * (contrast + 255.)) / (255. * (259. - contrast));
|
||||
let channel = |channel: f32| ((factor * (channel * 255. + brightness - 128.) + 128.) / 255.).clamp(0., 1.);
|
||||
|
||||
impl<Brightness, Contrast> Node<Image> for &BrightnessContrastNode<Brightness, Contrast>
|
||||
where
|
||||
Brightness: Node<(), Output = f64> + Copy,
|
||||
Contrast: Node<(), Output = f64> + Copy,
|
||||
{
|
||||
type Output = Image;
|
||||
fn eval(self, image: Image) -> Image {
|
||||
adjust_image_brightness_and_contrast(image, self.brightness.eval(()) as f32, self.contrast.eval(()) as f32)
|
||||
}
|
||||
}
|
||||
|
||||
impl<Brightness, Contrast> BrightnessContrastNode<Brightness, Contrast>
|
||||
where
|
||||
Brightness: Node<(), Output = f64>,
|
||||
Contrast: Node<(), Output = f64>,
|
||||
{
|
||||
pub fn new(brightness: Brightness, contrast: Contrast) -> Self {
|
||||
Self { brightness, contrast }
|
||||
}
|
||||
}
|
||||
|
||||
// https://www.dfstudios.co.uk/articles/programming/image-programming-algorithms/image-processing-algorithms-part-6-gamma-correction/
|
||||
fn image_gamma(mut image: Image, gamma: f32) -> Image {
|
||||
let inverse_gamma = 1. / gamma;
|
||||
let channel = |channel: f32| channel.powf(inverse_gamma);
|
||||
for pixel in &mut image.data {
|
||||
*pixel = Color::from_rgbaf32_unchecked(channel(pixel.r()), channel(pixel.g()), channel(pixel.b()), pixel.a())
|
||||
}
|
||||
|
@ -312,36 +207,44 @@ fn image_gamma(mut image: Image, gamma: f32) -> Image {
|
|||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct GammaNode<N: Node<(), Output = f64>>(N);
|
||||
|
||||
impl<N: Node<(), Output = f64>> Node<Image> for GammaNode<N> {
|
||||
type Output = Image;
|
||||
fn eval(self, image: Image) -> Image {
|
||||
image_gamma(image, self.0.eval(()) as f32)
|
||||
}
|
||||
}
|
||||
impl<N: Node<(), Output = f64> + Copy> Node<Image> for &GammaNode<N> {
|
||||
type Output = Image;
|
||||
fn eval(self, image: Image) -> Image {
|
||||
image_gamma(image, self.0.eval(()) as f32)
|
||||
}
|
||||
pub struct GammaNode<G> {
|
||||
gamma: G,
|
||||
}
|
||||
|
||||
impl<N: Node<(), Output = f64> + Copy> GammaNode<N> {
|
||||
pub fn new(node: N) -> Self {
|
||||
Self(node)
|
||||
// https://www.dfstudios.co.uk/articles/programming/image-programming-algorithms/image-processing-algorithms-part-6-gamma-correction/
|
||||
#[node_macro::node_fn(GammaNode)]
|
||||
fn image_gamma(mut image: Image, gamma: f64) -> Image {
|
||||
let inverse_gamma = 1. / gamma;
|
||||
let channel = |channel: f32| channel.powf(inverse_gamma as f32);
|
||||
for pixel in &mut image.data {
|
||||
*pixel = Color::from_rgbaf32_unchecked(channel(pixel.r()), channel(pixel.g()), channel(pixel.b()), pixel.a())
|
||||
}
|
||||
image
|
||||
}
|
||||
|
||||
fn image_opacity(mut image: Image, opacity_multiplier: f32) -> Image {
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct OpacityNode<O> {
|
||||
opacity_multiplier: O,
|
||||
}
|
||||
|
||||
#[node_macro::node_fn(OpacityNode)]
|
||||
fn image_opacity(mut image: Image, opacity_multiplier: f64) -> Image {
|
||||
let opacity_multiplier = opacity_multiplier as f32;
|
||||
for pixel in &mut image.data {
|
||||
*pixel = Color::from_rgbaf32_unchecked(pixel.r(), pixel.g(), pixel.b(), pixel.a() * opacity_multiplier)
|
||||
}
|
||||
image
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct PosterizeNode<P> {
|
||||
posterize_value: P,
|
||||
}
|
||||
|
||||
// Based on http://www.axiomx.com/posterize.htm
|
||||
fn posterize(mut image: Image, posterize_value: f32) -> Image {
|
||||
#[node_macro::node_fn(PosterizeNode)]
|
||||
fn posterize(mut image: Image, posterize_value: f64) -> Image {
|
||||
let posterize_value = posterize_value as f32;
|
||||
let number_of_areas = posterize_value.recip();
|
||||
let size_of_areas = (posterize_value - 1.).recip();
|
||||
let channel = |channel: f32| (channel / number_of_areas).floor() * size_of_areas;
|
||||
|
@ -351,9 +254,15 @@ fn posterize(mut image: Image, posterize_value: f32) -> Image {
|
|||
image
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct ExposureNode<E> {
|
||||
exposure: E,
|
||||
}
|
||||
|
||||
// Based on https://stackoverflow.com/questions/12166117/what-is-the-math-behind-exposure-adjustment-on-photoshop
|
||||
fn exposure(mut image: Image, exposure: f32) -> Image {
|
||||
let multiplier = 2f32.powf(exposure);
|
||||
#[node_macro::node_fn(ExposureNode)]
|
||||
fn exposure(mut image: Image, exposure: f64) -> Image {
|
||||
let multiplier = 2f32.powf(exposure as f32);
|
||||
let channel = |channel: f32| channel * multiplier;
|
||||
for pixel in &mut image.data {
|
||||
*pixel = Color::from_rgbaf32_unchecked(channel(pixel.r()), channel(pixel.g()), channel(pixel.b()), pixel.a())
|
||||
|
@ -362,69 +271,15 @@ fn exposure(mut image: Image, exposure: f32) -> Image {
|
|||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct PosterizeNode<N: Node<(), Output = f64>>(N);
|
||||
|
||||
impl<N: Node<(), Output = f64>> Node<Image> for PosterizeNode<N> {
|
||||
type Output = Image;
|
||||
fn eval(self, image: Image) -> Image {
|
||||
posterize(image, self.0.eval(()) as f32)
|
||||
}
|
||||
}
|
||||
impl<N: Node<(), Output = f64> + Copy> Node<Image> for &PosterizeNode<N> {
|
||||
type Output = Image;
|
||||
fn eval(self, image: Image) -> Image {
|
||||
posterize(image, self.0.eval(()) as f32)
|
||||
}
|
||||
pub struct ImaginateNode<E> {
|
||||
cached: E,
|
||||
}
|
||||
|
||||
impl<N: Node<(), Output = f64> + Copy> PosterizeNode<N> {
|
||||
pub fn new(node: N) -> Self {
|
||||
Self(node)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct OpacityNode<N: Node<(), Output = f64>>(N);
|
||||
|
||||
impl<N: Node<(), Output = f64>> Node<Image> for OpacityNode<N> {
|
||||
type Output = Image;
|
||||
fn eval(self, image: Image) -> Image {
|
||||
image_opacity(image, self.0.eval(()) as f32)
|
||||
}
|
||||
}
|
||||
impl<N: Node<(), Output = f64> + Copy> Node<Image> for &OpacityNode<N> {
|
||||
type Output = Image;
|
||||
fn eval(self, image: Image) -> Image {
|
||||
image_opacity(image, self.0.eval(()) as f32)
|
||||
}
|
||||
}
|
||||
|
||||
impl<N: Node<(), Output = f64> + Copy> OpacityNode<N> {
|
||||
pub fn new(node: N) -> Self {
|
||||
Self(node)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct ExposureNode<N: Node<(), Output = f64>>(N);
|
||||
|
||||
impl<N: Node<(), Output = f64>> Node<Image> for ExposureNode<N> {
|
||||
type Output = Image;
|
||||
fn eval(self, image: Image) -> Image {
|
||||
exposure(image, self.0.eval(()) as f32)
|
||||
}
|
||||
}
|
||||
impl<N: Node<(), Output = f64> + Copy> Node<Image> for &ExposureNode<N> {
|
||||
type Output = Image;
|
||||
fn eval(self, image: Image) -> Image {
|
||||
exposure(image, self.0.eval(()) as f32)
|
||||
}
|
||||
}
|
||||
|
||||
impl<N: Node<(), Output = f64> + Copy> ExposureNode<N> {
|
||||
pub fn new(node: N) -> Self {
|
||||
Self(node)
|
||||
}
|
||||
// Based on https://stackoverflow.com/questions/12166117/what-is-the-math-behind-exposure-adjustment-on-photoshop
|
||||
#[node_macro::node_fn(ImaginateNode)]
|
||||
fn imaginate(image: Image, cached: Option<std::sync::Arc<graphene_core::raster::Image>>) -> Image {
|
||||
info!("Imaginating image with {} pixels", image.data.len());
|
||||
cached.map(|mut x| std::sync::Arc::make_mut(&mut x).clone()).unwrap_or(image)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
|
@ -190,7 +190,6 @@ static NODE_REGISTRY: &[(NodeIdentifier, NodeConstructor)] = &[
|
|||
})
|
||||
}),
|
||||
(NodeIdentifier::new("graphene_core::raster::BrightenColorNode", &[concrete!("&TypeErasedNode")]), |proto_node, stack| {
|
||||
info!("proto node {:?}", proto_node);
|
||||
stack.push_fn(|nodes| {
|
||||
let ConstructionArgs::Nodes(construction_nodes) = proto_node.construction_args else { unreachable!("Brighten Color Node constructed with out brightness input node") };
|
||||
let value_node = nodes.get(construction_nodes[0] as usize).unwrap();
|
||||
|
@ -272,7 +271,6 @@ static NODE_REGISTRY: &[(NodeIdentifier, NodeConstructor)] = &[
|
|||
(NodeIdentifier::new("graphene_std::raster::MapImageNode", &[]), |proto_node, stack| {
|
||||
if let ConstructionArgs::Nodes(operation_node_id) = proto_node.construction_args {
|
||||
stack.push_fn(move |nodes| {
|
||||
info!("Map image Depending upon id {:?}", operation_node_id);
|
||||
let operation_node = nodes.get(operation_node_id[0] as usize).unwrap();
|
||||
let operation_node: DowncastBothNode<_, Color, Color> = DowncastBothNode::new(operation_node);
|
||||
let map_node = DynAnyNode::new(graphene_std::raster::MapImageNode::new(operation_node));
|
||||
|
@ -403,6 +401,21 @@ static NODE_REGISTRY: &[(NodeIdentifier, NodeConstructor)] = &[
|
|||
}
|
||||
})
|
||||
}),
|
||||
(NodeIdentifier::new("graphene_std::raster::ImaginateNode", &[concrete!("&TypeErasedNode")]), |proto_node, stack| {
|
||||
stack.push_fn(move |nodes| {
|
||||
let ConstructionArgs::Nodes(construction_nodes) = proto_node.construction_args else { unreachable!("ImaginateNode constructed without inputs") };
|
||||
let value: DowncastBothNode<_, (), Option<std::sync::Arc<graphene_core::raster::Image>>> = DowncastBothNode::new(nodes.get(construction_nodes[15] as usize).unwrap());
|
||||
|
||||
let node = DynAnyNode::new(graphene_std::raster::ImaginateNode::new(value));
|
||||
|
||||
if let ProtoNodeInput::Node(node_id) = proto_node.input {
|
||||
let pre_node = nodes.get(node_id as usize).unwrap();
|
||||
(pre_node).then(node).into_type_erased()
|
||||
} else {
|
||||
node.into_type_erased()
|
||||
}
|
||||
})
|
||||
}),
|
||||
(NodeIdentifier::new("graphene_std::raster::ImageNode", &[concrete!("&str")]), |_proto_node, stack| {
|
||||
stack.push_fn(|_nodes| {
|
||||
let image = FnNode::new(|s: &str| graphene_std::raster::image_node::<&str>().eval(s).unwrap());
|
||||
|
|
20
node-graph/node-macro/Cargo.toml
Normal file
20
node-graph/node-macro/Cargo.toml
Normal file
|
@ -0,0 +1,20 @@
|
|||
[package]
|
||||
name = "node-macro"
|
||||
publish = false
|
||||
version = "0.0.0"
|
||||
rust-version = "1.65.0"
|
||||
authors = ["Graphite Authors <contact@graphite.rs>"]
|
||||
edition = "2021"
|
||||
readme = "../../README.md"
|
||||
homepage = "https://graphite.rs"
|
||||
repository = "https://github.com/GraphiteEditor/Graphite"
|
||||
license = "Apache-2.0"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[lib]
|
||||
proc-macro = true
|
||||
|
||||
[dependencies]
|
||||
syn = { version = "1.0", features = ["full"] }
|
||||
quote = "1.0"
|
83
node-graph/node-macro/src/lib.rs
Normal file
83
node-graph/node-macro/src/lib.rs
Normal file
|
@ -0,0 +1,83 @@
|
|||
use proc_macro::TokenStream;
|
||||
use quote::{format_ident, ToTokens};
|
||||
use syn::{parse_macro_input, FnArg, Ident, ItemFn, Pat, PatIdent, ReturnType};
|
||||
|
||||
#[proc_macro_attribute]
|
||||
pub fn node_fn(attr: TokenStream, item: TokenStream) -> TokenStream {
|
||||
let node_name = parse_macro_input!(attr as Ident);
|
||||
let function = parse_macro_input!(item as ItemFn);
|
||||
|
||||
let function_name = &function.sig.ident;
|
||||
let mut function_inputs = function.sig.inputs.iter().filter_map(|arg| if let FnArg::Typed(typed_arg) = arg { Some(typed_arg) } else { None });
|
||||
|
||||
// Extract primary input as first argument
|
||||
let primary_input = function_inputs.next().expect("Primary input required - set to `()` if not needed.");
|
||||
let Pat::Ident(PatIdent{ident: primary_input_ident,..} ) =&*primary_input.pat else{
|
||||
panic!("Expected ident as primary input.");
|
||||
};
|
||||
let primary_input_ty = &primary_input.ty;
|
||||
|
||||
// Extract secondary inputs as all other arguments
|
||||
let secondary_inputs = function_inputs.collect::<Vec<_>>();
|
||||
let secondary_idents = secondary_inputs
|
||||
.iter()
|
||||
.map(|input| {
|
||||
let Pat::Ident(PatIdent { ident: primary_input_ident,.. }) = &*input.pat else { panic!("Expected ident for secondary input."); };
|
||||
primary_input_ident
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// Extract the output type of the entire node - `()` by default
|
||||
let output = if let ReturnType::Type(_, ty) = &function.sig.output {
|
||||
ty.to_token_stream()
|
||||
} else {
|
||||
quote::quote!(())
|
||||
};
|
||||
|
||||
// Generics are simply `S0` through to `Sn-1` where n is the number of secondary inputs
|
||||
let generics = (0..secondary_inputs.len()).map(|x| format_ident!("S{x}")).collect::<Vec<_>>();
|
||||
// Bindings for all of the above generics to a node with an input of `()` and an output of the type in the function
|
||||
let where_clause = secondary_inputs
|
||||
.iter()
|
||||
.zip(&generics)
|
||||
.map(|(ty, name)| {
|
||||
let ty = &ty.ty;
|
||||
quote::quote!(#name: Node<(), Output = #ty>)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
quote::quote! {
|
||||
#function
|
||||
|
||||
impl <#(#generics),*> Node<#primary_input_ty> for #node_name<#(#generics),*>
|
||||
where
|
||||
#(#where_clause),* {
|
||||
|
||||
type Output = #output;
|
||||
fn eval(self, #primary_input_ident: #primary_input_ty) -> #output{
|
||||
#function_name(#primary_input_ident #(, self.#secondary_idents.eval(()))*)
|
||||
}
|
||||
}
|
||||
|
||||
impl <#(#generics),*> Node<#primary_input_ty> for &#node_name<#(#generics),*>
|
||||
where
|
||||
#(#where_clause + Copy),* {
|
||||
|
||||
type Output = #output;
|
||||
fn eval(self, #primary_input_ident: #primary_input_ty) -> #output{
|
||||
#function_name(#primary_input_ident #(, self.#secondary_idents.eval(()))*)
|
||||
}
|
||||
}
|
||||
|
||||
impl <#(#generics),*> #node_name<#(#generics),*>
|
||||
where
|
||||
#(#where_clause + Copy),* {
|
||||
pub fn new(#(#secondary_idents: #generics),*) -> Self{
|
||||
Self{
|
||||
#(#secondary_idents),*
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
.into()
|
||||
}
|
|
@ -1,18 +0,0 @@
|
|||
[package]
|
||||
name = "graph-proc-macros"
|
||||
version = "0.1.0"
|
||||
authors = ["Graphite Authors <contact@graphite.rs>"]
|
||||
edition = "2021"
|
||||
publish = false
|
||||
license = "MIT OR Apache-2.0"
|
||||
|
||||
[lib]
|
||||
path = "src/lib.rs"
|
||||
proc-macro = true
|
||||
|
||||
[dependencies]
|
||||
proc-macro2 = "1.0"
|
||||
proc_macro_roids = "0.7"
|
||||
syn = { version = "1.0", features = ["full"] }
|
||||
quote = "1.0"
|
||||
graphene-core = {path = "../gcore"}
|
|
@ -1,98 +0,0 @@
|
|||
use proc_macro::TokenStream;
|
||||
use proc_macro_roids::*;
|
||||
use quote::{quote, ToTokens};
|
||||
use syn::punctuated::Punctuated;
|
||||
use syn::{parse_macro_input, FnArg, ItemFn, Pat, Type};
|
||||
|
||||
fn extract_type(a: FnArg) -> Type {
|
||||
match a {
|
||||
FnArg::Typed(p) => *p.ty, // notice `ty` instead of `pat`
|
||||
_ => panic!("Not supported on types with `self`!"),
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_arg_types(fn_args: Punctuated<FnArg, syn::token::Comma>) -> Vec<Type> {
|
||||
fn_args.into_iter().map(extract_type).collect::<Vec<_>>()
|
||||
}
|
||||
|
||||
fn extract_arg_idents(fn_args: Punctuated<FnArg, syn::token::Comma>) -> Vec<Pat> {
|
||||
fn_args.into_iter().map(extract_arg_pat).collect::<Vec<_>>()
|
||||
}
|
||||
|
||||
fn extract_arg_pat(a: FnArg) -> Pat {
|
||||
match a {
|
||||
FnArg::Typed(p) => *p.pat,
|
||||
_ => panic!("Not supported on types with `self`!"),
|
||||
}
|
||||
}
|
||||
|
||||
#[proc_macro_attribute] // 2
|
||||
pub fn to_node(_attr: TokenStream, item: TokenStream) -> TokenStream {
|
||||
let string = item.to_string();
|
||||
let item2 = item;
|
||||
let parsed = parse_macro_input!(item2 as ItemFn); // 3
|
||||
|
||||
//item.extend(generate_to_string(parsed, string)); // 4
|
||||
//item
|
||||
generate_to_string(parsed, string)
|
||||
}
|
||||
|
||||
fn generate_to_string(parsed: ItemFn, string: String) -> TokenStream {
|
||||
let whole_function = parsed.clone();
|
||||
//let fn_body = parsed.block; // function body
|
||||
let sig = parsed.sig; // function signature
|
||||
|
||||
//let vis = parsed.vis; // visibility, pub or not
|
||||
let generics = sig.generics;
|
||||
let fn_args = sig.inputs; // comma separated args
|
||||
let fn_return_type = sig.output; // return type
|
||||
let fn_name = sig.ident; // function name/identifier
|
||||
let idents = extract_arg_idents(fn_args.clone());
|
||||
let types = extract_arg_types(fn_args);
|
||||
let types = types.iter().map(|t| t.to_token_stream()).collect::<Vec<_>>();
|
||||
let idents = idents.iter().map(|t| t.to_token_stream()).collect::<Vec<_>>();
|
||||
let _const_idents = idents
|
||||
.iter()
|
||||
.map(|t| {
|
||||
let name = t.to_string().to_uppercase();
|
||||
quote! {#name}
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let node_fn_name = fn_name.append("_node");
|
||||
let struct_name = fn_name.append("_input");
|
||||
let return_type_string = fn_return_type.to_token_stream().to_string().replace("->", "");
|
||||
let arg_type_string = types.iter().map(|t| t.to_string()).collect::<Vec<_>>().join(", ");
|
||||
let error = format!("called {} with the wrong type", fn_name);
|
||||
|
||||
let x = quote! {
|
||||
//#whole_function
|
||||
mod #fn_name {
|
||||
#[derive(Copy, Clone)]
|
||||
type F32Node<'n> = &'n (dyn Node<'n, (), Output = &'n (dyn Any + 'static)> + 'n);
|
||||
struct #struct_name {
|
||||
#(#idents: #types,)*
|
||||
}
|
||||
impl Node for #struct_name {
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
fn #node_fn_name #generics() -> Node<'static> {
|
||||
Node { func: Box::new(move |x| {
|
||||
let args = x.downcast::<(#(#types,)*)>().expect(#error);
|
||||
let (#(#idents,)*) = *args;
|
||||
#whole_function
|
||||
|
||||
Box::new(#fn_name(#(#idents,)*))
|
||||
}),
|
||||
code: #string.to_string(),
|
||||
return_type: #return_type_string.trim().to_string(),
|
||||
args: format!("({})",#arg_type_string.trim()),
|
||||
position: (0., 0.),
|
||||
}
|
||||
}
|
||||
};
|
||||
//panic!("{}\n{:?}", x.to_string(), x);
|
||||
x.into()
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue