Update GPU execution and quantization to new node system (#1070)

* Update GPU and quantization to new node system

Squashed commit of the following:

commit 3b69bdafed79f0bb1279609537a8eeead3f06830
Author: Dennis Kobert <dennis@kobert.dev>
Date:   Sun Mar 5 11:37:17 2023 +0100

    Disable dev tools by default

commit dbbbedd68e48d1162442574ad8877c9922d40e4a
Merge: b1018eb5 a8f6e11e
Author: Dennis Kobert <dennis@kobert.dev>
Date:   Sun Mar 5 10:45:00 2023 +0100

    Merge branch 'vite' into tauri-restructure-lite

commit b1018eb5ee56c2d23f9d5a4f034608ec684bd746
Merge: 3195833e 0512cb24
Author: Dennis Kobert <dennis@kobert.dev>
Date:   Fri Mar 3 17:06:21 2023 +0100

    Merge branch 'master' into tauri-restructure-lite

commit 3195833e4088a4ed7984955c72617b27b7e39bfc
Author: Dennis Kobert <dennis@kobert.dev>
Date:   Fri Mar 3 17:06:02 2023 +0100

    Bump number of samples

commit 3e57e1e3280759cf4f75726635e31d2b8e9387f9
Author: Dennis Kobert <dennis@kobert.dev>
Date:   Fri Mar 3 16:55:52 2023 +0100

    Move part of quantization code to gcore

commit 10c15b0bc6ffb51e2bf2d94cd4eb0e24d761fb6f
Merge: 2b3db45a 8fe8896c
Author: Dennis Kobert <dennis@kobert.dev>
Date:   Fri Mar 3 14:28:56 2023 +0100

    Merge remote-tracking branch 'origin/master' into tauri-restructure-lite

commit 2b3db45aee44a20660f0b1204666bb81e5a7e4b6
Author: Dennis Kobert <dennis@kobert.dev>
Date:   Fri Mar 3 14:17:11 2023 +0100

    Fix types in node registry

commit 9122f35c0ba9a86255709680d744a48d3c7dcac4
Merge: 26eefc43 2cf4ee0f
Author: Dennis Kobert <dennis@kobert.dev>
Date:   Fri Mar 3 01:04:55 2023 +0100

    Merge remote-tracking branch 'origin/master' into tauri-restructure-lite

commit 26eefc437eaad873f8d38fdb1fae0a1e3ec189e4
Author: Dennis Kobert <dennis@kobert.dev>
Date:   Thu Mar 2 23:05:53 2023 +0100

    Add Quantize node to document_node_types

commit 3f7606a91329200b2c025010d4a0cffee840a11c
Author: Dennis Kobert <dennis@kobert.dev>
Date:   Thu Mar 2 17:47:51 2023 +0100

    Add quantization nodes to node registry

commit 22d8e477ef79eef5b57b1dc9805e41bbf81cae43
Author: Dennis Kobert <dennis@kobert.dev>
Date:   Thu Mar 2 17:13:28 2023 +0100

    Introduce scopes (#1053)

    * Implement let binding

    * Add lambda inputs

    * Fix tests

    * Fix proto network formatting

    * Generate a template Scoped network by default

    * Add comment to explain the lambda parameter

    * Move binding wrapping out of the template

    * Fix errors cause by image frames

commit 9e0c29d92a164d4a4063e93480e1e289ef5243fe
Author: Alexandru Ică <alexandru@seyhanlee.com>
Date:   Thu Mar 2 15:55:10 2023 +0200

    Make use of ImageFrame in the node system more extensively (#1055) (#1062)

    Make the node system use ImageFrame more extensively (#1055)

commit 5912ef9a1a807917eeb90c1f4835bd8a5de9c821
Author: Dennis Kobert <dennis@kobert.dev>
Date:   Wed Mar 1 16:15:21 2023 +0100

    Split quantization into multiple nodes

commit 285d7b76c176b3e2679ea24eecb38ef867a79f3b
Author: Dennis Kobert <dennis@kobert.dev>
Date:   Mon Feb 27 12:35:57 2023 +0100

    Fix gpu support

commit e0b6327eebba8caf7545c4fedc6670abc4c3652e
Author: Dennis Kobert <dennis@kobert.dev>
Date:   Thu Feb 16 22:08:53 2023 +0100

    Don't watch frontend files when using tauri

commit 58ae146f6da935cfd37afbd25e1c331b615252da
Author: Dennis Kobert <dennis@kobert.dev>
Date:   Thu Feb 16 21:48:54 2023 +0100

    Migrate vue code base to vite

commit f996390cc312618a60f98ccb9cd515f1bae5006d
Author: Dennis Kobert <dennis@kobert.dev>
Date:   Thu Feb 16 19:34:33 2023 +0100

    Start migrating vue to use vite

commit 29d752f47cfd1c74ee51fac6f3d75557a378471c
Author: Dennis Kobert <dennis@kobert.dev>
Date:   Thu Feb 16 19:00:53 2023 +0100

    Kill cargo watch process automatically

commit 4d1c76b07acadbf609dbab7d57d9a7769b81d4b5
Author: Dennis Kobert <dennis@kobert.dev>
Date:   Thu Feb 16 17:37:27 2023 +0100

    Start playing around with vite infrastructure

commit 8494f5e9227aa433fd5ca75b268a6a96b2706b36
Author: Locria Cyber <74560659+locriacyber@users.noreply.github.com>
Date:   Thu Jan 19 18:40:46 2023 +0000

    Fix import style and eslint rules

commit 92490f7774a7351bb40091bcec78f79c28704768
Author: Locria Cyber <74560659+locriacyber@users.noreply.github.com>
Date:   Thu Jan 19 18:25:09 2023 +0000

    Fix icons

commit dc67821abad87f8ff780b12ae96668af2f7bb355
Author: Locria Cyber <74560659+locriacyber@users.noreply.github.com>
Date:   Thu Jan 19 18:20:48 2023 +0000

    Add license generator with rollup

commit 441e339d31b76dac4f91321d39a39900b5a79bc1
Author: Locria Cyber <74560659+locriacyber@users.noreply.github.com>
Date:   Thu Jan 19 18:14:22 2023 +0000

    Use eslint --fix to fix TS-in-svelte type imports. Now it compiles.

commit 2e847d339e7dcd51ed4c4677ed337c1e20636724
Author: Locria Cyber <74560659+locriacyber@users.noreply.github.com>
Date:   Thu Jan 19 17:31:49 2023 +0000

    Remove webpack and plugins

commit 3adab1b7f40ff17b91163e7ca47a403ef3c02fbc
Author: Dennis Kobert <dennis@kobert.dev>
Date:   Thu Mar 2 16:10:19 2023 +0100

    Fix errors cause by image frames

commit 4e5f838995e213b4696225a473b9c56c0084e7a8
Author: Alexandru Ică <alexandru@seyhanlee.com>
Date:   Thu Mar 2 15:55:10 2023 +0200

    Make use of ImageFrame in the node system more extensively (#1055) (#1062)

    Make the node system use ImageFrame more extensively (#1055)

commit 1d4b0e29c693a53c068f1a30f0e857a9c1a59587
Author: Dennis Kobert <dennis@kobert.dev>
Date:   Wed Mar 1 15:13:51 2023 +0100

    Update node graph guide readme with new syntax (#1061)

commit 6735d8c61f5709e22d2b22abd037bab417e868d6
Author: Rob Nadal <Robnadal44@gmail.com>
Date:   Tue Feb 28 18:59:06 2023 -0500

    Bezier-rs: Add function to smoothly join bezier curves (#1037)

    * Added bezier join

    * Stylistic changes per review

commit cd1d7aa7fbcce39fbbf7762d131ee16ad9cb46dd
Author: Dennis Kobert <dennis@kobert.dev>
Date:   Wed Feb 22 23:42:32 2023 +0100

    Implement let binding

    Add lambda inputs

    Fix tests

    Fix proto network formatting

    Generate a template Scoped network by default

    Add comment to explain the lambda parameter

    Move binding wrapping out of the template

* Update package-lock.json

* Regenerate package-lock.json and fix lint errors

* Readd git keep dir

* Revert change to panic.ts

* Fix clippy warnings

* Apply code review

* Clean up node_registry

* Fix test / spriv -> spirv typos
This commit is contained in:
Dennis Kobert 2023-03-05 13:22:14 +01:00 committed by Keavon Chambers
parent 4ea3802df1
commit b55e233fff
38 changed files with 2282 additions and 1924 deletions

820
Cargo.lock generated

File diff suppressed because it is too large Load diff

View file

@ -44,5 +44,8 @@ opt-level = 3
[profile.dev.package.graphite-wasm-svelte]
opt-level = 3
[profile.dev.package.autoquant]
opt-level = 3
#[profile.dev]
#opt-level = 3

View file

@ -449,9 +449,9 @@ fn static_nodes() -> Vec<DocumentNodeType> {
DocumentNodeType {
name: "GpuImage",
category: "Image Adjustments",
identifier: NodeImplementation::proto("graphene_std::executor::MapGpuSingleImageNode"),
identifier: NodeImplementation::proto("graphene_std::executor::MapGpuSingleImageNode<_>"),
inputs: vec![
DocumentInputType::new("Image", TaggedValue::ImageFrame(ImageFrame::empty()), true),
DocumentInputType::value("Image", TaggedValue::ImageFrame(ImageFrame::empty()), true),
DocumentInputType {
name: "Path",
data_type: FrontendGraphDataType::Text,
@ -463,9 +463,9 @@ fn static_nodes() -> Vec<DocumentNodeType> {
},
#[cfg(feature = "quantization")]
DocumentNodeType {
name: "QuantizeImage",
category: "Image Adjustments",
identifier: NodeImplementation::proto("graphene_std::quantization::GenerateQuantizationNode"),
name: "Generate Quantization",
category: "Quantization",
identifier: NodeImplementation::proto("graphene_std::quantization::GenerateQuantizationNode<_, _>"),
inputs: vec![
DocumentInputType {
name: "Image",
@ -483,7 +483,47 @@ fn static_nodes() -> Vec<DocumentNodeType> {
default: NodeInput::value(TaggedValue::U32(0), false),
},
],
outputs: vec![DocumentOutputType::new("Image", FrontendGraphDataType::Raster)],
outputs: vec![DocumentOutputType::new("Quantization", FrontendGraphDataType::General)],
properties: node_properties::quantize_properties,
},
#[cfg(feature = "quantization")]
DocumentNodeType {
name: "Quantize Image",
category: "Quantization",
identifier: NodeImplementation::proto("graphene_core::quantization::QuantizeNode<_>"),
inputs: vec![
DocumentInputType {
name: "Image",
data_type: FrontendGraphDataType::Raster,
default: NodeInput::value(TaggedValue::ImageFrame(ImageFrame::empty()), true),
},
DocumentInputType {
name: "Quantization",
data_type: FrontendGraphDataType::General,
default: NodeInput::value(TaggedValue::Quantization(core::array::from_fn(|_| Default::default())), true),
},
],
outputs: vec![DocumentOutputType::new("Encoded", FrontendGraphDataType::Raster)],
properties: node_properties::quantize_properties,
},
#[cfg(feature = "quantization")]
DocumentNodeType {
name: "DeQuantize Image",
category: "Quantization",
identifier: NodeImplementation::proto("graphene_core::quantization::DeQuantizeNode<_>"),
inputs: vec![
DocumentInputType {
name: "Encoded",
data_type: FrontendGraphDataType::Raster,
default: NodeInput::value(TaggedValue::ImageFrame(ImageFrame::empty()), true),
},
DocumentInputType {
name: "Quantization",
data_type: FrontendGraphDataType::General,
default: NodeInput::value(TaggedValue::Quantization(core::array::from_fn(|_| Default::default())), true),
},
],
outputs: vec![DocumentOutputType::new("Decoded", FrontendGraphDataType::Raster)],
properties: node_properties::quantize_properties,
},
DocumentNodeType {

View file

@ -71,23 +71,24 @@ fn start_widgets(document_node: &DocumentNode, node_id: NodeId, index: usize, na
widgets
}
// fn text_widget(document_node: &DocumentNode, node_id: NodeId, index: usize, name: &str, blank_assist: bool) -> Vec<WidgetHolder> {
// let mut widgets = start_widgets(document_node, node_id, index, name, FrontendGraphDataType::Text, blank_assist);
#[cfg(feature = "gpu")]
fn text_widget(document_node: &DocumentNode, node_id: NodeId, index: usize, name: &str, blank_assist: bool) -> Vec<WidgetHolder> {
let mut widgets = start_widgets(document_node, node_id, index, name, FrontendGraphDataType::Text, blank_assist);
// if let NodeInput::Value {
// tagged_value: TaggedValue::String(x),
// exposed: false,
// } = &document_node.inputs[index]
// {
// widgets.extend_from_slice(&[
// WidgetHolder::unrelated_separator(),
// TextInput::new(x.clone())
// .on_update(update_value(|x: &TextInput| TaggedValue::String(x.value.clone()), node_id, index))
// .widget_holder(),
// ])
// }
// widgets
// }
if let NodeInput::Value {
tagged_value: TaggedValue::String(x),
exposed: false,
} = &document_node.inputs[index]
{
widgets.extend_from_slice(&[
WidgetHolder::unrelated_separator(),
TextInput::new(x.clone())
.on_update(update_value(|x: &TextInput| TaggedValue::String(x.value.clone()), node_id, index))
.widget_holder(),
])
}
widgets
}
fn text_area_widget(document_node: &DocumentNode, node_id: NodeId, index: usize, name: &str, blank_assist: bool) -> Vec<WidgetHolder> {
let mut widgets = start_widgets(document_node, node_id, index, name, FrontendGraphDataType::Text, blank_assist);

File diff suppressed because it is too large Load diff

View file

@ -9,8 +9,9 @@
"build": "vue-cli-service build || (npm run print-building-help && exit 1)",
"lint": "vue-cli-service lint || (npm run print-linting-help && exit 1)",
"lint-no-fix": "vue-cli-service lint --no-fix || (npm run print-linting-help && exit 1)",
"tauri:build": "vue-cli-service tauri:build",
"tauri:serve": "vue-cli-service tauri:serve",
"tauri:build": "npm run tauri:build-wasm && vue-cli-service tauri:build",
"tauri:build-wasm": "wasm-pack build wasm --release -- --features tauri",
"tauri:serve": "echo 'Make sure you build the wasm binary for tauri using `npm run tauri:build-wasm`' && npm run serve",
"print-building-help": "echo 'Graphite project failed to build. Did you remember to `npm install` the dependencies in `/frontend`?'",
"print-linting-help": "echo 'Graphite project had lint errors, or may have otherwise failed. In the latter case, did you remember to `npm install` the dependencies in `/frontend`?'"
},

View file

@ -31,6 +31,7 @@ futures = "0.3.25"
[features]
gpu = ["graphite-editor/gpu"]
quantization = ["graphite-editor/quantization"]
# by default Tauri runs in production mode
# when `tauri dev` runs it is executed with `cargo run --no-default-features` if `devPath` is an URL
default = [ "custom-protocol" ]

View file

@ -15,7 +15,6 @@ use http::{Response, StatusCode};
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::Mutex;
use tauri::Manager;
static IMAGES: Mutex<Option<HashMap<String, FrontendImageData>>> = Mutex::new(None);
static EDITOR: Mutex<Option<Editor>> = Mutex::new(None);
@ -66,8 +65,9 @@ async fn main() {
tauri::Builder::default()
.invoke_handler(tauri::generate_handler![set_random_seed, handle_message])
.setup(|app| {
app.get_window("main").unwrap().open_devtools();
.setup(|_app| {
//use tauri::Manager;
//_app.get_window("main").unwrap().open_devtools();
Ok(())
})
.run(tauri::generate_context!())

View file

@ -117,8 +117,7 @@
import { defineComponent, type PropType } from "vue";
import { debouncer } from "@/components/widgets/debounce";
import type { Widget } from "@/wasm-communication/messages";
import { isWidgetColumn, isWidgetRow, type WidgetColumn, type WidgetRow } from "@/wasm-communication/messages";
import { isWidgetColumn, isWidgetRow, type WidgetColumn, type WidgetRow, type Widget } from "@/wasm-communication/messages";
import PivotAssist from "@/components/widgets/assists/PivotAssist.vue";
import BreadcrumbTrailButtons from "@/components/widgets/buttons/BreadcrumbTrailButtons.vue";

View file

@ -4,8 +4,7 @@ import { type IconName } from "@/utility-functions/icons";
import { browserVersion, operatingSystem } from "@/utility-functions/platform";
import { stripIndents } from "@/utility-functions/strip-indents";
import { type Editor } from "@/wasm-communication/editor";
import type { TextLabel } from "@/wasm-communication/messages";
import { type TextButtonWidget, type WidgetLayout, Widget, DisplayDialogPanic } from "@/wasm-communication/messages";
import { type TextButtonWidget, type TextLabel, type WidgetLayout, Widget, DisplayDialogPanic } from "@/wasm-communication/messages";
export function createPanicManager(editor: Editor, dialogState: DialogState): void {
// Code panic dialog and console error

View file

@ -4,8 +4,7 @@
import { blobToBase64 } from "@/utility-functions/files";
import { type RequestResult, requestWithUploadDownloadProgress } from "@/utility-functions/network";
import { type Editor } from "@/wasm-communication/editor";
import type { XY } from "@/wasm-communication/messages";
import { type ImaginateGenerationParameters } from "@/wasm-communication/messages";
import { type ImaginateGenerationParameters, type XY } from "@/wasm-communication/messages";
const MAX_POLLING_RETRIES = 4;
const SERVER_STATUS_CHECK_TIMEOUT = 5000;

View file

@ -496,7 +496,7 @@ const mouseCursorIconCSSNames = {
Rotate: "custom-rotate",
} as const;
export type MouseCursor = keyof typeof mouseCursorIconCSSNames;
export type MouseCursorIcon = typeof mouseCursorIconCSSNames[MouseCursor];
export type MouseCursorIcon = (typeof mouseCursorIconCSSNames)[MouseCursor];
export class UpdateMouseCursor extends JsMessage {
@Transform(({ value }: { value: MouseCursor }) => mouseCursorIconCSSNames[value] || "alias")
@ -1169,7 +1169,7 @@ const widgetSubTypes = [
{ value: TextLabel, name: "TextLabel" },
{ value: PivotAssist, name: "PivotAssist" },
];
export type WidgetPropsSet = InstanceType<typeof widgetSubTypes[number]["value"]>;
export type WidgetPropsSet = InstanceType<(typeof widgetSubTypes)[number]["value"]>;
export class Widget {
constructor(props: WidgetPropsSet, widgetId: bigint) {

View file

@ -14,7 +14,7 @@ documentation = "https://docs.rs/dyn-any"
[dependencies]
dyn-any-derive = { path = "derive", version = "0.2.0", optional = true }
log = { version = "0.4", optional = true }
glam = { version = "0.22", optional = true }
glam = { version = "0.22", optional = true, default-features = false }
[features]
derive = ["dyn-any-derive"]
@ -25,7 +25,7 @@ rc = []
glam = ["dep:glam"]
alloc = []
large-atomics = []
std = ["alloc", "rc"]
std = ["alloc", "rc", "glam/default"]
default = ["std", "large-atomics"]
[package.metadata.docs.rs]

View file

@ -5,7 +5,7 @@ pub async fn compile<I, O>(network: NodeNetwork) -> Result<Vec<u8>, reqwest::Err
let client = reqwest::Client::new();
let compile_request = CompileRequest::new(network, std::any::type_name::<I>().to_owned(), std::any::type_name::<O>().to_owned());
let response = client.post("http://localhost:3000/compile/spriv").json(&compile_request).send();
let response = client.post("http://localhost:3000/compile/spirv").json(&compile_request).send();
let response = response.await?;
response.bytes().await.map(|b| b.to_vec())
}

View file

@ -1,6 +1,5 @@
use gpu_compiler_bin_wrapper::CompileRequest;
use graph_craft::concrete;
use graph_craft::document::value::TaggedValue;
use graph_craft::document::*;
use graph_craft::*;
@ -18,13 +17,7 @@ fn main() {
0,
DocumentNode {
name: "Inc Node".into(),
inputs: vec![
NodeInput::Network(concrete!(u32)),
NodeInput::Value {
tagged_value: TaggedValue::U32(1),
exposed: false,
},
],
inputs: vec![NodeInput::Network(concrete!(u32))],
implementation: DocumentNodeImplementation::Network(add_network()),
metadata: DocumentNodeMetadata::default(),
},
@ -34,13 +27,13 @@ fn main() {
};
let compile_request = CompileRequest::new(network, "u32".to_owned(), "u32".to_owned());
let response = client.post("http://localhost:3000/compile/spriv").json(&compile_request).send().unwrap();
let response = client.post("http://localhost:3000/compile/spirv").json(&compile_request).send().unwrap();
println!("response: {:?}", response);
}
fn add_network() -> NodeNetwork {
NodeNetwork {
inputs: vec![0, 0],
inputs: vec![0],
outputs: vec![NodeOutput::new(1, 0)],
disabled: vec![],
previous_outputs: None,
@ -48,10 +41,10 @@ fn add_network() -> NodeNetwork {
(
0,
DocumentNode {
name: "Cons".into(),
inputs: vec![NodeInput::Network(concrete!(u32)), NodeInput::Network(concrete!(u32))],
name: "Dup".into(),
inputs: vec![NodeInput::Network(concrete!(u32))],
metadata: DocumentNodeMetadata::default(),
implementation: DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_core::structural::ConsNode")),
implementation: DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_core::ops::DupNode")),
},
),
(

View file

@ -23,14 +23,14 @@ async fn main() {
let app = Router::new()
.route("/", get(|| async { "Hello from compilation server!" }))
.route("/compile", get(|| async { "Supported targets: spirv" }))
.route("/compile/spriv", post(post_compile_spriv))
.route("/compile/spirv", post(post_compile_spirv))
.with_state(shared_state);
// run it with hyper on localhost:3000
axum::Server::bind(&"0.0.0.0:3000".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
}
async fn post_compile_spriv(State(state): State<Arc<AppState>>, Json(compile_request): Json<CompileRequest>) -> Result<Vec<u8>, StatusCode> {
async fn post_compile_spirv(State(state): State<Arc<AppState>>, Json(compile_request): Json<CompileRequest>) -> Result<Vec<u8>, StatusCode> {
let path = std::env::var("CARGO_MANIFEST_DIR").unwrap() + "/../gpu-compiler/Cargo.toml";
compile_request.compile(state.compile_dir.path().to_str().expect("non utf8 tempdir path"), &path).map_err(|e| {
eprintln!("compilation failed: {}", e);

View file

@ -9,20 +9,20 @@ license = "MIT OR Apache-2.0"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[features]
std = ["dyn-any", "dyn-any/std", "alloc"]
std = ["dyn-any", "dyn-any/std", "alloc", "glam/std", "specta"]
default = ["async", "serde", "kurbo", "log", "std"]
log = ["dep:log"]
serde = ["dep:serde", "glam/serde"]
gpu = ["spirv-std", "bytemuck", "glam/bytemuck", "dyn-any"]
gpu = ["spirv-std", "bytemuck", "glam/bytemuck", "dyn-any", "glam/libm"]
async = ["async-trait", "alloc"]
nightly = []
alloc = ["dyn-any", "bezier-rs"]
alloc = ["dyn-any", "bezier-rs", "once_cell"]
type_id_logging = []
[dependencies]
dyn-any = {path = "../../libraries/dyn-any", features = ["derive", "glam"], optional = true, default-features = false }
spirv-std = { git = "https://github.com/EmbarkStudios/rust-gpu", features = ["glam"] , optional = true}
spirv-std = { version = "0.5", features = ["glam"] , optional = true}
bytemuck = {version = "1.8", features = ["derive"], optional = true}
async-trait = {version = "0.1", optional = true}
serde = {version = "1.0", features = ["derive"], optional = true, default-features = false }
@ -32,10 +32,11 @@ bezier-rs = { path = "../../libraries/bezier-rs", optional = true }
kurbo = { git = "https://github.com/linebender/kurbo.git", features = [
"serde",
], optional = true }
glam = { version = "^0.22", default-features = false, features = ["scalar-math", "libm"]}
rand_chacha = "0.3.1"
spin = "0.9.2"
glam = { version = "^0.22", default-features = false, features = ["scalar-math"]}
node-macro = {path = "../node-macro"}
specta.workspace = true
once_cell = { version = "1.17.0", default-features = false }
specta.optional = true
once_cell = { version = "1.17.0", default-features = false, optional = true }
# forma = { version = "0.1.0", package = "forma-render" }

View file

@ -11,6 +11,7 @@ pub mod consts;
pub mod generic;
pub mod ops;
pub mod structural;
#[cfg(feature = "std")]
pub mod uuid;
pub mod value;
@ -22,6 +23,8 @@ pub mod raster;
#[cfg(feature = "alloc")]
pub mod vector;
pub mod quantization;
use core::any::TypeId;
pub use raster::Color;
@ -33,6 +36,7 @@ pub trait Node<'i, Input: 'i>: 'i {
#[cfg(feature = "alloc")]
mod types;
#[cfg(feature = "alloc")]
pub use types::*;
pub trait NodeIO<'i, Input: 'i>: 'i + Node<'i, Input>

View file

@ -0,0 +1,83 @@
use crate::raster::Color;
use crate::Node;
use dyn_any::{DynAny, StaticType};
#[derive(Clone, Debug, DynAny, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Quantization {
pub fn_index: usize,
pub a: f32,
pub b: f32,
pub c: f32,
pub d: f32,
}
impl core::hash::Hash for Quantization {
fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
self.fn_index.hash(state);
self.a.to_bits().hash(state);
self.b.to_bits().hash(state);
self.c.to_bits().hash(state);
self.d.to_bits().hash(state);
}
}
impl Default for Quantization {
fn default() -> Self {
Self {
fn_index: Default::default(),
a: 1.,
b: Default::default(),
c: Default::default(),
d: Default::default(),
}
}
}
pub type QuantizationChannels = [Quantization; 4];
fn quantize(value: f32, quantization: &Quantization) -> f32 {
let Quantization { fn_index, a, b, c, d } = quantization;
match fn_index {
1 => ((value + a) * d).abs().ln() * b + c,
_ => a * value + b,
}
}
fn decode(value: f32, quantization: &Quantization) -> f32 {
let Quantization { fn_index, a, b, c, d } = quantization;
match fn_index {
1 => -(-c / b).exp() * (a * d * (c / b).exp() - (value / b).exp()) / d,
_ => (value - b) / a,
}
}
pub struct QuantizeNode<Quantization> {
quantization: Quantization,
}
#[node_macro::node_fn(QuantizeNode)]
fn quantize_fn<'a>(color: Color, quantization: [Quantization; 4]) -> Color {
let quant = quantization.as_slice();
let r = quantize(color.r(), &quant[0]);
let g = quantize(color.g(), &quant[1]);
let b = quantize(color.b(), &quant[2]);
let a = quantize(color.a(), &quant[3]);
Color::from_rgbaf32_unchecked(r, g, b, a)
}
pub struct DeQuantizeNode<Quantization> {
quantization: Quantization,
}
#[node_macro::node_fn(DeQuantizeNode)]
fn dequantize_fn<'a>(color: Color, quantization: [Quantization; 4]) -> Color {
let quant = quantization.as_slice();
let r = decode(color.r(), &quant[0]);
let g = decode(color.g(), &quant[1]);
let b = decode(color.b(), &quant[2]);
let a = decode(color.a(), &quant[3]);
Color::from_rgbaf32_unchecked(r, g, b, a)
}

View file

@ -2,6 +2,9 @@ use core::{fmt::Debug, marker::PhantomData};
use crate::Node;
#[cfg(target_arch = "spirv")]
use spirv_std::num_traits::float::Float;
pub mod color;
pub use self::color::Color;
@ -180,6 +183,7 @@ impl<'a> ImageWindowIterator<'a> {
}
}
#[cfg(not(target_arch = "spirv"))]
impl<'a> Iterator for ImageWindowIterator<'a> {
type Item = (Color, (i32, i32));
#[inline]
@ -194,6 +198,9 @@ impl<'a> Iterator for ImageWindowIterator<'a> {
if self.y > max_y {
return None;
}
#[cfg(target_arch = "spirv")]
let value = None;
#[cfg(not(target_arch = "spirv"))]
let value = Some((self.image.data[(self.x + self.y * self.image.width) as usize], (self.x as i32 - start_x, self.y as i32 - start_y)));
self.x += 1;
@ -245,21 +252,49 @@ where
input.for_each(|x| map_node.eval(x));
}
#[cfg(target_arch = "spirv")]
const NOTHING: () = ();
use dyn_any::{DynAny, StaticType};
#[derive(Clone, Debug, PartialEq, DynAny, Default, Copy)]
#[derive(Clone, Debug, PartialEq, DynAny, Copy)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
pub struct ImageSlice<'a> {
pub width: u32,
pub height: u32,
#[cfg(not(target_arch = "spirv"))]
pub data: &'a [Color],
#[cfg(target_arch = "spirv")]
pub data: &'a (),
}
#[allow(clippy::derivable_impls)]
impl<'a> Default for ImageSlice<'a> {
#[cfg(not(target_arch = "spirv"))]
fn default() -> Self {
Self {
width: Default::default(),
height: Default::default(),
data: Default::default(),
}
}
#[cfg(target_arch = "spirv")]
fn default() -> Self {
Self {
width: Default::default(),
height: Default::default(),
data: &NOTHING,
}
}
}
impl ImageSlice<'_> {
#[cfg(not(target_arch = "spirv"))]
pub const fn empty() -> Self {
Self { width: 0, height: 0, data: &[] }
}
}
#[cfg(not(target_arch = "spirv"))]
impl<'a> IntoIterator for ImageSlice<'a> {
type Item = &'a Color;
type IntoIter = core::slice::Iter<'a, Color>;
@ -268,6 +303,7 @@ impl<'a> IntoIterator for ImageSlice<'a> {
}
}
#[cfg(not(target_arch = "spirv"))]
impl<'a> IntoIterator for &'a ImageSlice<'a> {
type Item = &'a Color;
type IntoIter = core::slice::Iter<'a, Color>;

View file

@ -4,8 +4,12 @@ use crate::Node;
use core::fmt::Debug;
use dyn_any::{DynAny, StaticType};
#[cfg(target_arch = "spirv")]
use spirv_std::num_traits::float::Float;
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Default, Clone, Copy, Eq, PartialEq, DynAny, specta::Type, Hash)]
#[cfg_attr(feature = "std", derive(specta::Type))]
#[derive(Debug, Default, Clone, Copy, Eq, PartialEq, DynAny, Hash)]
pub enum LuminanceCalculation {
#[default]
SRGB,
@ -27,8 +31,8 @@ impl LuminanceCalculation {
}
}
impl std::fmt::Display for LuminanceCalculation {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
impl core::fmt::Display for LuminanceCalculation {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
match self {
LuminanceCalculation::SRGB => write!(f, "sRGB"),
LuminanceCalculation::Perceptual => write!(f, "Perceptual"),
@ -73,7 +77,8 @@ impl BlendMode {
}
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Default, Clone, Copy, Eq, PartialEq, DynAny, specta::Type, Hash)]
#[cfg_attr(feature = "std", derive(specta::Type))]
#[derive(Debug, Default, Clone, Copy, Eq, PartialEq, DynAny, Hash)]
pub enum BlendMode {
#[default]
// Basic group
@ -116,8 +121,8 @@ pub enum BlendMode {
Luminosity,
}
impl std::fmt::Display for BlendMode {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
impl core::fmt::Display for BlendMode {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
match self {
BlendMode::Normal => write!(f, "Normal"),
@ -368,8 +373,7 @@ fn blend_node(input: (Color, Color), blend_mode: BlendMode, opacity: f64) -> Col
BlendMode::Color => backdrop.blend_color(source_color),
BlendMode::Luminosity => backdrop.blend_luminosity(source_color),
}
.lerp(backdrop, actual_opacity)
.unwrap();
.lerp(backdrop, actual_opacity);
}
#[derive(Debug, Clone, Copy)]

View file

@ -20,7 +20,8 @@ use bytemuck::{Pod, Zeroable};
#[repr(C)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "gpu", derive(Pod, Zeroable))]
#[derive(Debug, Clone, Copy, PartialEq, Default, DynAny, specta::Type)]
#[cfg_attr(feature = "std", derive(specta::Type))]
#[derive(Debug, Default, Clone, Copy, PartialEq, DynAny)]
pub struct Color {
red: f32,
green: f32,
@ -30,7 +31,7 @@ pub struct Color {
#[allow(clippy::derived_hash_with_manual_eq)]
impl Hash for Color {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
self.red.to_bits().hash(state);
self.green.to_bits().hash(state);
self.blue.to_bits().hash(state);
@ -119,7 +120,6 @@ impl Color {
/// use graphene_core::raster::color::Color;
/// let color = Color::from_hsla(0.5, 0.2, 0.3, 1.);
/// ```
#[cfg(not(target_arch = "spirv"))]
pub fn from_hsla(hue: f32, saturation: f32, lightness: f32, alpha: f32) -> Color {
let temp1 = if lightness < 0.5 {
lightness * (saturation + 1.)
@ -127,12 +127,16 @@ impl Color {
lightness + saturation - lightness * saturation
};
let temp2 = 2. * lightness - temp1;
#[cfg(not(target_arch = "spirv"))]
let rem = |x: f32| x.rem_euclid(1.);
#[cfg(target_arch = "spirv")]
let rem = |x: f32| x.rem_euclid(&1.);
let mut red = (hue + 1. / 3.).rem_euclid(1.);
let mut green = hue.rem_euclid(1.);
let mut blue = (hue - 1. / 3.).rem_euclid(1.);
let mut red = rem(hue + 1. / 3.);
let mut green = rem(hue);
let mut blue = rem(hue - 1. / 3.);
for channel in [&mut red, &mut green, &mut blue] {
fn map_channel(channel: &mut f32, temp2: f32, temp1: f32) {
*channel = if *channel * 6. < 1. {
temp2 + (temp1 - temp2) * 6. * *channel
} else if *channel * 2. < 1. {
@ -144,6 +148,9 @@ impl Color {
}
.clamp(0., 1.);
}
map_channel(&mut red, temp2, temp1);
map_channel(&mut green, temp2, temp1);
map_channel(&mut blue, temp2, temp1);
Color { red, green, blue, alpha }
}
@ -427,6 +434,7 @@ impl Color {
/// let color = Color::from_rgba8(0x7C, 0x67, 0xFA, 0x61);
/// assert!("7C67FA61" == color.rgba_hex())
/// ```
#[cfg(feature = "std")]
pub fn rgba_hex(&self) -> String {
format!(
"{:02X?}{:02X?}{:02X?}{:02X?}",
@ -443,6 +451,7 @@ impl Color {
/// let color = Color::from_rgba8(0x7C, 0x67, 0xFA, 0x61);
/// assert!("7C67FA" == color.rgb_hex())
/// ```
#[cfg(feature = "std")]
pub fn rgb_hex(&self) -> String {
format!("{:02X?}{:02X?}{:02X?}", (self.r() * 255.) as u8, (self.g() * 255.) as u8, (self.b() * 255.) as u8,)
}
@ -536,9 +545,9 @@ impl Color {
/// Linearly interpolates between two colors based on t.
///
/// T must be between 0 and 1.
pub fn lerp(self, other: Color, t: f32) -> Option<Self> {
pub fn lerp(self, other: Color, t: f32) -> Self {
assert!((0. ..=1.).contains(&t));
Color::from_rgbaf32(
Color::from_rgbaf32_unchecked(
self.red + ((other.red - self.red) * t),
self.green + ((other.green - self.green) * t),
self.blue + ((other.blue - self.blue) * t),
@ -600,6 +609,7 @@ impl Color {
blue: f(self.blue, other.blue),
alpha: self.alpha,
};
#[cfg(feature = "log")]
if *self == Color::RED {
debug!("{} {} {} {}", color.red, color.green, color.blue, color.alpha);
}

View file

@ -133,8 +133,8 @@ impl Gradient {
((time - self.positions[index].0) / self.positions.get(index + 1).map(|end| end.0 - self.positions[index].0).unwrap_or_default()) as f32,
),
// Use the start or the end colour if applicable
(Some(v), _) | (_, Some(v)) => Some(v),
_ => Some(Color::WHITE),
(Some(v), _) | (_, Some(v)) => v,
_ => Color::WHITE,
};
// Compute the correct index to keep the positions in order
@ -146,7 +146,7 @@ impl Gradient {
let new_color = get_color(index - 1, new_position);
// Insert the new stop
self.positions.insert(index, (new_position, new_color));
self.positions.insert(index, (new_position, Some(new_color)));
Some(index)
}

View file

@ -2,6 +2,12 @@
# It is not intended for manual editing.
version = 3
[[package]]
name = "Inflector"
version = "0.11.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fe438c63458706e03479442743baae6c88256498e6431708f6dfc520a26515d3"
[[package]]
name = "ahash"
version = "0.7.6"
@ -274,6 +280,15 @@ dependencies = [
"crypto-common",
]
[[package]]
name = "document-features"
version = "0.2.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e493c573fce17f00dcab13b6ac057994f3ce17d1af4dc39bfd482b83c6eb6157"
dependencies = [
"litrs",
]
[[package]]
name = "dyn-any"
version = "0.2.1"
@ -365,7 +380,6 @@ version = "0.22.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "12f597d56c1bd55a811a1be189459e8fad2bbc272616375602443bdfb37fa774"
dependencies = [
"num-traits",
"serde",
]
@ -430,6 +444,7 @@ dependencies = [
"num-traits",
"rand_chacha",
"serde",
"specta",
]
[[package]]
@ -443,7 +458,9 @@ dependencies = [
"kurbo",
"log",
"node-macro",
"once_cell",
"serde",
"specta",
]
[[package]]
@ -597,6 +614,12 @@ dependencies = [
"cc",
]
[[package]]
name = "litrs"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f9275e0933cf8bb20f008924c0cb07a0692fe54d8064996520bf998de9eb79aa"
[[package]]
name = "log"
version = "0.4.17"
@ -616,6 +639,7 @@ checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d"
name = "node-macro"
version = "0.0.0"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
@ -651,9 +675,9 @@ dependencies = [
[[package]]
name = "once_cell"
version = "1.16.0"
version = "1.17.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "86f0b0d4bf799edbc74508c1e8bf170ff5f41238e5f8225603ca7caaae2b7860"
checksum = "b7e5500299e16ebb147ae15a00a942af264cf3688f47923b8fc2cd5858f23ad3"
[[package]]
name = "parse-zoneinfo"
@ -664,6 +688,12 @@ dependencies = [
"regex",
]
[[package]]
name = "paste"
version = "1.0.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d01a5bd0424d00070b0098dd17ebca6f961a959dead1dbcbbbc1d1cd8d3deeba"
[[package]]
name = "percent-encoding"
version = "2.2.0"
@ -873,14 +903,15 @@ checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
[[package]]
name = "rustc_codegen_spirv"
version = "0.4.0"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1ea7712b3de402aa159d0073fba3e025836f921847490f3663743eb8d6bf0220"
checksum = "1b803b49618cdde99e1065af9415f489b993374765e30a6b80f2bea2cca65914"
dependencies = [
"ar",
"either",
"hashbrown 0.11.2",
"indexmap",
"lazy_static",
"libc",
"num-traits",
"once_cell",
@ -899,9 +930,9 @@ dependencies = [
[[package]]
name = "rustc_codegen_spirv-types"
version = "0.4.0"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f00bec144cf0240b503dce9fa0dd3bca1e94da3726ddbf2df055bc6155430c1b"
checksum = "330aedc6b09b9bf3c58cc7fb942c1377310a9cff00fae9e4f6cc09a7a28f542e"
dependencies = [
"rspirv",
"serde",
@ -1004,6 +1035,32 @@ dependencies = [
"serde",
]
[[package]]
name = "specta"
version = "0.0.6"
source = "git+https://github.com/oscartbeaumont/rspc?rev=9725ddbfe40183debc055b88c37910eb6f818eae#9725ddbfe40183debc055b88c37910eb6f818eae"
dependencies = [
"document-features",
"glam",
"once_cell",
"paste",
"serde",
"serde_json",
"specta-macros",
]
[[package]]
name = "specta-macros"
version = "0.0.6"
source = "git+https://github.com/oscartbeaumont/rspc?rev=9725ddbfe40183debc055b88c37910eb6f818eae#9725ddbfe40183debc055b88c37910eb6f818eae"
dependencies = [
"Inflector",
"proc-macro2",
"quote",
"syn",
"termcolor",
]
[[package]]
name = "spirt"
version = "0.1.0"
@ -1034,9 +1091,9 @@ dependencies = [
[[package]]
name = "spirv-builder"
version = "0.4.0"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5095b11ee4033cfc146df5bc2ad5671c3c3d900ac46a3c30282bdf387d902250"
checksum = "93f656f97ac742e5603843d2ea3ea644cdf5630635b3320ec87666385766e7ab"
dependencies = [
"memchr",
"raw-string",

View file

@ -24,7 +24,7 @@ base64 = "0.13"
bytemuck = { version = "1.8" }
nvtx = { version = "1.1.1", optional = true }
tempfile = "3"
spirv-builder = { version = "0.4", default-features = false, features=["use-installed-tools"] }
spirv-builder = { version = "0.5", default-features = false, features=["use-installed-tools"] }
tera = { version = "1.17.1" }
anyhow = "1.0.66"
serde_json = "1.0.91"

View file

@ -33,7 +33,7 @@ pub fn compile_spirv(network: &graph_craft::document::NodeNetwork, input_type: &
if !output.status.success() {
return Err(anyhow::anyhow!("cargo failed: {}", String::from_utf8_lossy(&output.stderr)));
}
Ok(output.stdout)
Ok(std::fs::read(compile_dir.unwrap().to_owned() + "/shader.spv")?)
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]

View file

@ -1,3 +1,3 @@
[toolchain]
channel = "nightly-2022-10-29"
channel = "nightly-2022-12-18"
components = ["rust-src", "rustc-dev", "llvm-tools-preview", "clippy", "cargofmt", "rustc"]

View file

@ -30,8 +30,8 @@ pub fn create_files(matadata: &Metadata, network: &ProtoNetwork, compile_dir: &P
let cargo_toml = create_cargo_toml(matadata)?;
std::fs::write(cargo_file, cargo_toml)?;
let toolchain_file = compile_dir.join("rust-toolchain");
let toolchain = include_str!("templates/rust-toolchain");
let toolchain_file = compile_dir.join("rust-toolchain.toml");
let toolchain = include_str!("templates/rust-toolchain.toml");
std::fs::write(toolchain_file, toolchain)?;
// create src dir
@ -69,7 +69,7 @@ pub fn serialize_gpu(network: &ProtoNetwork, input_type: &str, output_type: &str
nodes.push(Node {
id,
fqn: fqn.to_string(),
fqn: fqn.to_string().split("<").next().unwrap().to_owned(),
args: node.construction_args.new_function_args(),
});
}

View file

@ -11,7 +11,7 @@ fn main() -> anyhow::Result<()> {
let compile_dir = std::env::args().nth(3).map(|x| std::path::PathBuf::from(&x)).unwrap_or(tempfile::tempdir()?.into_path());
let network: NodeNetwork = serde_json::from_reader(&mut stdin)?;
let compiler = graph_craft::executor::Compiler {};
let proto_network = compiler.compile(network, true);
let proto_network = compiler.compile_single(network, true).unwrap();
dbg!(&compile_dir);
let metadata = compiler::Metadata::new("project".to_owned(), vec!["test@example.com".to_owned()]);
@ -20,7 +20,9 @@ fn main() -> anyhow::Result<()> {
let result = compiler::compile(&compile_dir)?;
let bytes = std::fs::read(result.module.unwrap_single())?;
stdout.write_all(&bytes)?;
// TODO: properly resolve this
let spirv_path = compile_dir.join("shader.spv");
std::fs::write(&spirv_path, &bytes)?;
Ok(())
}

View file

@ -13,5 +13,5 @@ crate-type = ["dylib", "lib"]
libm = { git = "https://github.com/rust-lang/libm", tag = "0.2.5" }
[dependencies]
spirv-std = { git = "https://github.com/EmbarkStudios/rust-gpu" , features= ["glam"]}
spirv-std = { version = "0.5" , features= ["glam"]}
graphene-core = {path = "{{gcore_path}}", default-features = false, features = ["gpu"]}

View file

@ -1,3 +1,3 @@
[toolchain]
channel = "nightly-2022-10-29"
channel = "nightly-2022-12-18"
components = ["rust-src", "rustc-dev", "llvm-tools-preview", "clippy", "cargofmt", "rustc"]

View file

@ -17,13 +17,13 @@ pub mod gpu {
#[spirv(global_invocation_id)] global_id: UVec3,
#[spirv(storage_buffer, descriptor_set = 0, binding = 0)] a: &[{{input_type}}],
#[spirv(storage_buffer, descriptor_set = 0, binding = 1)] y: &mut [{{output_type}}],
#[spirv(push_constant)] push_consts: &graphene_core::gpu::PushConstants,
//#[spirv(push_constant)] push_consts: &graphene_core::gpu::PushConstants,
) {
let gid = global_id.x as usize;
// Only process up to n, which is the length of the buffers.
if global_id.x < push_consts.n {
//if global_id.x < push_consts.n {
y[gid] = node_graph(a[gid]);
}
//}
}
fn node_graph(input: {{input_type}}) -> {{output_type}} {

View file

@ -44,6 +44,7 @@ pub enum TaggedValue {
FillType(graphene_core::vector::style::FillType),
GradientType(graphene_core::vector::style::GradientType),
GradientPositions(Vec<(f64, Option<graphene_core::Color>)>),
Quantization(graphene_core::quantization::QuantizationChannels),
}
#[allow(clippy::derived_hash_with_manual_eq)]
@ -175,6 +176,10 @@ impl Hash for TaggedValue {
color.hash(state);
}
}
Self::Quantization(quantized_image) => {
31.hash(state);
quantized_image.hash(state);
}
}
}
}
@ -213,6 +218,7 @@ impl<'a> TaggedValue {
TaggedValue::FillType(x) => Box::new(x),
TaggedValue::GradientType(x) => Box::new(x),
TaggedValue::GradientPositions(x) => Box::new(x),
TaggedValue::Quantization(x) => Box::new(x),
}
}
@ -250,6 +256,7 @@ impl<'a> TaggedValue {
TaggedValue::FillType(_) => concrete!(graphene_core::vector::style::FillType),
TaggedValue::GradientType(_) => concrete!(graphene_core::vector::style::GradientType),
TaggedValue::GradientPositions(_) => concrete!(Vec<(f64, Option<graphene_core::Color>)>),
TaggedValue::Quantization(_) => concrete!(graphene_core::quantization::QuantizationChannels),
}
}
}

View file

@ -1,74 +1,51 @@
use graph_craft::document::*;
use graph_craft::proto::*;
use graphene_core::raster::Image;
use graphene_core::raster::*;
use graphene_core::value::ValueNode;
use graphene_core::Node;
use graphene_core::*;
use bytemuck::Pod;
use core::marker::PhantomData;
use dyn_any::StaticTypeSized;
pub struct MapGpuNode<NN: Node<()>, I: IntoIterator<Item = S>, S: StaticTypeSized + Sync + Send + Pod, O: StaticTypeSized + Sync + Send + Pod>(pub NN, PhantomData<(S, I, O)>);
impl<'n, I: IntoIterator<Item = S>, NN: Node<(), Output = &'n NodeNetwork> + Copy, S: StaticTypeSized + Sync + Send + Pod, O: StaticTypeSized + Sync + Send + Pod> Node<I>
for &MapGpuNode<NN, I, S, O>
{
type Output = Vec<O>;
fn eval(self, input: I) -> Self::Output {
let network = self.0.eval(());
map_gpu_impl(network, input)
}
pub struct MapGpuNode<O, Network> {
network: Network,
_o: PhantomData<O>,
}
fn map_gpu_impl<I: IntoIterator<Item = S>, S: StaticTypeSized + Sync + Send + Pod, O: StaticTypeSized + Sync + Send + Pod>(network: &NodeNetwork, input: I) -> Vec<O> {
#[node_macro::node_fn(MapGpuNode<_O>)]
fn map_gpu<I: IntoIterator<Item = S>, S: StaticTypeSized + Sync + Send + Pod, _O: StaticTypeSized + Sync + Send + Pod>(input: I, network: &'any_input NodeNetwork) -> Vec<_O> {
use graph_craft::executor::Executor;
let bytes = compilation_client::compile_sync::<S, O>(network.clone()).unwrap();
let bytes = compilation_client::compile_sync::<S, _O>(network.clone()).unwrap();
let words = unsafe { std::slice::from_raw_parts(bytes.as_ptr() as *const u32, bytes.len() / 4) };
use wgpu_executor::{Context, GpuExecutor};
let executor: GpuExecutor<S, O> = GpuExecutor::new(Context::new_sync().unwrap(), words.into(), "gpu::eval".into()).unwrap();
let executor: GpuExecutor<S, _O> = GpuExecutor::new(Context::new_sync().unwrap(), words.into(), "gpu::eval".into()).unwrap();
let data: Vec<_> = input.into_iter().collect();
let result = executor.execute(Box::new(data)).unwrap();
let result = dyn_any::downcast::<Vec<O>>(result).unwrap();
let result = dyn_any::downcast::<Vec<_O>>(result).unwrap();
*result
}
impl<'n, I: IntoIterator<Item = S>, NN: Node<(), Output = &'n NodeNetwork> + Copy, S: StaticTypeSized + Sync + Send + Pod, O: StaticTypeSized + Sync + Send + Pod> Node<I> for MapGpuNode<NN, I, S, O> {
type Output = Vec<O>;
fn eval(self, input: I) -> Self::Output {
let network = self.0.eval(());
map_gpu_impl(network, input)
}
pub struct MapGpuSingleImageNode<N> {
node: N,
}
impl<I: IntoIterator<Item = S>, NN: Node<()>, S: StaticTypeSized + Sync + Pod + Send, O: StaticTypeSized + Sync + Send + Pod> MapGpuNode<NN, I, S, O> {
pub const fn new(network: NN) -> Self {
MapGpuNode(network, PhantomData)
}
}
pub struct MapGpuSingleImageNode<NN: Node<(), Output = String>>(pub NN);
impl<NN: Node<(), Output = String> + Copy> Node<Image> for MapGpuSingleImageNode<NN> {
type Output = Image;
fn eval(self, input: Image) -> Self::Output {
let node = self.0.eval(());
#[node_macro::node_fn(MapGpuSingleImageNode)]
fn map_gpu_single_image(input: Image, node: String) -> Image {
use graph_craft::document::*;
use graph_craft::NodeIdentifier;
let identifier = NodeIdentifier {
name: std::borrow::Cow::Owned(node),
types: std::borrow::Cow::Borrowed(&[]),
};
let identifier = NodeIdentifier { name: std::borrow::Cow::Owned(node) };
let network = NodeNetwork {
inputs: vec![0],
disabled: vec![],
previous_output: None,
output: 0,
previous_outputs: None,
outputs: vec![NodeOutput::new(0, 0)],
nodes: [(
0,
DocumentNode {
name: "Image filter Node".into(),
inputs: vec![NodeInput::Network],
inputs: vec![NodeInput::Network(concrete!(Color))],
implementation: DocumentNodeImplementation::Unresolved(identifier),
metadata: DocumentNodeMetadata::default(),
},
@ -78,44 +55,7 @@ impl<NN: Node<(), Output = String> + Copy> Node<Image> for MapGpuSingleImageNode
};
let value_network = ValueNode::new(network);
let map_node = MapGpuNode::new(&value_network);
let map_node = MapGpuNode::new(value_network);
let data = map_node.eval(input.data.clone());
Image { data, ..input }
}
}
impl<NN: Node<(), Output = String> + Copy> Node<Image> for &MapGpuSingleImageNode<NN> {
type Output = Image;
fn eval(self, input: Image) -> Self::Output {
let node = self.0.eval(());
use graph_craft::document::*;
let identifier = NodeIdentifier {
name: std::borrow::Cow::Owned(node),
types: std::borrow::Cow::Borrowed(&[]),
};
let network = NodeNetwork {
inputs: vec![0],
output: 0,
disabled: vec![],
previous_output: None,
nodes: [(
0,
DocumentNode {
name: "Image filter Node".into(),
inputs: vec![NodeInput::Network],
implementation: DocumentNodeImplementation::Unresolved(identifier),
metadata: DocumentNodeMetadata::default(),
},
)]
.into_iter()
.collect(),
};
let value_network = ValueNode::new(network);
let map_node = MapGpuNode::new(&value_network);
let data = map_node.eval(input.data.clone());
Image { data, ..input }
}
}

View file

@ -1,54 +1,61 @@
use graphene_core::raster::{Color, Image};
use dyn_any::{DynAny, StaticType};
use graphene_core::quantization::*;
use graphene_core::raster::{Color, ImageFrame};
use graphene_core::Node;
/// The `GenerateQuantizationNode` encodes the brightness of each channel of the image as an integer number
/// sepified by the samples parameter. This node is used to asses the loss of visual information when
/// quantizing the image using different fit functions.
pub struct GenerateQuantizationNode<N: Node<(), Output = u32>, M: Node<(), Output = u32>> {
pub struct GenerateQuantizationNode<N, M> {
samples: N,
function: M,
}
#[node_macro::node_fn(GenerateQuantizationNode)]
fn generate_quantization_fn(image: Image, samples: u32, function: u32) -> Image {
// Scale the input image, this can be removed by adding an extra parameter to the fit function.
let max_energy = 16380.;
let data: Vec<f64> = image.data.iter().flat_map(|x| vec![x.r() as f64, x.g() as f64, x.b() as f64]).collect();
let data: Vec<f64> = data.iter().map(|x| x * max_energy).collect();
fn generate_quantization_fn(image_frame: ImageFrame, samples: u32, function: u32) -> [Quantization; 4] {
let image = image_frame.image;
let len = image.data.len().min(10000);
let mut channels: Vec<_> = (0..4).map(|_| Vec::with_capacity(image.data.len())).collect();
image
.data
.iter()
.enumerate()
.filter(|(i, _)| i % (image.data.len() / len) == 0)
.map(|(_, x)| vec![x.r() as f64, x.g() as f64, x.b() as f64, x.a() as f64])
.for_each(|x| x.into_iter().enumerate().for_each(|(i, value)| channels[i].push(value)));
log::info!("Quantizing {} samples", channels[0].len());
log::info!("In {} channels", channels.len());
let quantization: Vec<Quantization> = channels.into_iter().map(|x| generate_quantization_per_channel(x, samples)).collect();
core::array::from_fn(|i| quantization[i].clone())
}
fn generate_quantization_per_channel(data: Vec<f64>, samples: u32) -> Quantization {
let mut dist = autoquant::integrate_distribution(data);
autoquant::drop_duplicates(&mut dist);
let dist = autoquant::normalize_distribution(dist.as_slice());
let max = dist.last().unwrap().0;
let linear = Box::new(autoquant::SimpleFitFn {
/*let linear = Box::new(autoquant::SimpleFitFn {
function: move |x| x / max,
inverse: move |x| x * max,
name: "identity",
});
let best = match function {
0 => linear as Box<dyn autoquant::FitFn>,
1 => linear as Box<dyn autoquant::FitFn>,
2 => Box::new(autoquant::models::OptimizedLog::new(dist, 20)) as Box<dyn autoquant::FitFn>,
_ => linear as Box<dyn autoquant::FitFn>,
});*/
let linear = Quantization {
fn_index: 0,
a: max as f32,
b: 0.,
c: 0.,
d: 0.,
};
let roundtrip = |sample: f32| -> f32 {
let encoded = autoquant::encode(sample as f64 * max_energy, best.as_ref(), samples);
let decoded = autoquant::decode(encoded, best.as_ref(), samples) / max_energy;
log::trace!("{} enc: {} dec: {}", sample, encoded, decoded);
decoded as f32
let log_fit = autoquant::models::OptimizedLog::new(dist, samples as u64);
let parameters = log_fit.parameters();
let log_fit = Quantization {
fn_index: 1,
a: parameters[0] as f32,
b: parameters[1] as f32,
c: parameters[2] as f32,
d: parameters[3] as f32,
};
let new_data = image
.data
.iter()
.map(|c| {
let r = roundtrip(c.r());
let g = roundtrip(c.g());
let b = roundtrip(c.b());
let a = c.a();
Color::from_rgbaf32_unchecked(r, g, b, a)
})
.collect();
Image { data: new_data, ..image }
log_fit
}

View file

@ -158,6 +158,6 @@ mod tests {
let compiler = Compiler {};
let protograph = compiler.compile_single(network, true).expect("Graph should be generated");
let exec = DynamicExecutor::new(protograph).map(|e| panic!("The network should not type check: {:#?}", e)).unwrap_err();
let _exec = DynamicExecutor::new(protograph).map(|e| panic!("The network should not type check: {:#?}", e)).unwrap_err();
}
}

View file

@ -24,6 +24,8 @@ use crate::executor::NodeContainer;
use dyn_any::StaticType;
use graphene_core::quantization::QuantizationChannels;
macro_rules! construct_node {
($args: ident, $path:ty, [$($type:tt),*]) => {{
let mut args: Vec<TypeErasedPinnedRef<'static>> = $args.clone();
@ -133,6 +135,8 @@ fn node_registry() -> HashMap<NodeIdentifier, HashMap<NodeIOTypes, NodeConstruct
register_node!(graphene_core::ops::AddParameterNode<_>, input: f64, params: [&f64]),
register_node!(graphene_core::ops::AddParameterNode<_>, input: &f64, params: [&f64]),
register_node!(graphene_core::ops::SomeNode, input: ImageFrame, params: []),
#[cfg(feature = "gpu")]
register_node!(graphene_std::executor::MapGpuSingleImageNode<_>, input: Image, params: [String]),
vec![(
NodeIdentifier::new("graphene_core::structural::ComposeNode<_, _, _>"),
|args| {
@ -292,9 +296,23 @@ fn node_registry() -> HashMap<NodeIdentifier, HashMap<NodeIOTypes, NodeConstruct
},
NodeIOTypes::new(concrete!(Image), concrete!(&Image), vec![]),
),
(
NodeIdentifier::new("graphene_std::memo::CacheNode"),
|_| {
let node: CacheNode<QuantizationChannels> = graphene_std::memo::CacheNode::new();
let any = DynAnyRefNode::new(node);
any.into_type_erased()
},
NodeIOTypes::new(concrete!(QuantizationChannels), concrete!(&QuantizationChannels), vec![]),
),
],
register_node!(graphene_core::structural::ConsNode<_, _>, input: Image, params: [&str]),
register_node!(graphene_std::raster::ImageFrameNode<_>, input: Image, params: [DAffine2]),
#[cfg(feature = "quantization")]
register_node!(graphene_std::quantization::GenerateQuantizationNode<_, _>, input: ImageFrame, params: [u32, u32]),
raster_node!(graphene_core::quantization::QuantizeNode<_>, params: [QuantizationChannels]),
raster_node!(graphene_core::quantization::DeQuantizeNode<_>, params: [QuantizationChannels]),
register_node!(graphene_core::ops::CloneNode<_>, input: &QuantizationChannels, params: []),
register_node!(graphene_core::vector::TransformNode<_, _, _, _>, input: VectorData, params: [DVec2, f64, DVec2, DVec2]),
register_node!(graphene_core::vector::SetFillNode<_, _, _, _, _, _, _>, input: VectorData, params: [ graphene_core::vector::style::FillType, graphene_core::Color, graphene_core::vector::style::GradientType, DVec2, DVec2, DAffine2, Vec<(f64, Option<graphene_core::Color>)>]),
register_node!(graphene_core::vector::SetStrokeNode<_, _, _, _, _, _, _>, input: VectorData, params: [graphene_core::Color, f64, Vec<f32>, f64, graphene_core::vector::style::LineCap, graphene_core::vector::style::LineJoin, f64]),
@ -304,158 +322,6 @@ fn node_registry() -> HashMap<NodeIdentifier, HashMap<NodeIOTypes, NodeConstruct
input: graphene_core::vector::bezier_rs::Subpath<graphene_core::uuid::ManipulatorGroupId>,
params: []
),
/*
(NodeIdentifier::new("graphene_std::raster::ImageNode", &[concrete!("&str")]), |_proto_node, stack| {
stack.push_fn(|_nodes| {
let image = FnNode::new(|s: &str| graphene_std::raster::image_node::<&str>().eval(s).unwrap());
let node: DynAnyNode<_, &str, _, _> = DynAnyNode::new(image);
node.into_type_erased()
})
}),
(NodeIdentifier::new("graphene_std::raster::ExportImageNode", &[concrete!("&str")]), |proto_node, stack| {
stack.push_fn(|nodes| {
let pre_node = nodes.get(proto_node.input.unwrap_node() as usize).unwrap();
let image = FnNode::new(|input: (Image, &str)| graphene_std::raster::export_image_node().eval(input).unwrap());
let node: DynAnyNode<_, (Image, &str), _, _> = DynAnyNode::new(image);
let node = (pre_node).then(node);
node.into_type_erased()
})
}),
(
NodeIdentifier::new("graphene_core::structural::ConsNode", &[concrete!("Image"), concrete!("&str")]),
|proto_node, stack| {
let node_id = proto_node.input.unwrap_node() as usize;
if let ConstructionArgs::Nodes(cons_node_arg) = proto_node.construction_args {
stack.push_fn(move |nodes| {
let pre_node = nodes.get(node_id).unwrap();
let cons_node_arg = nodes.get(cons_node_arg[0] as usize).unwrap();
let cons_node = ConsNode::new(DowncastNode::<_, &str>::new(cons_node_arg));
let node: DynAnyNode<_, Image, _, _> = DynAnyNode::new(cons_node);
let node = (pre_node).then(node);
node.into_type_erased()
})
} else {
unimplemented!()
}
},
),
(NodeIdentifier::new("graphene_std::vector::generator_nodes::UnitCircleGenerator", &[]), |_proto_node, stack| {
stack.push_fn(|_nodes| DynAnyNode::new(graphene_std::vector::generator_nodes::UnitCircleGenerator).into_type_erased())
}),
(NodeIdentifier::new("graphene_std::vector::generator_nodes::UnitSquareGenerator", &[]), |_proto_node, stack| {
stack.push_fn(|_nodes| DynAnyNode::new(graphene_std::vector::generator_nodes::UnitSquareGenerator).into_type_erased())
}),
(NodeIdentifier::new("graphene_std::vector::generator_nodes::BlitSubpath", &[]), |proto_node, stack| {
stack.push_fn(move |nodes| {
let ConstructionArgs::Nodes(construction_nodes) = proto_node.construction_args else { unreachable!("BlitSubpath without subpath input") };
let value_node = nodes.get(construction_nodes[0] as usize).unwrap();
let input_node: DowncastBothNode<_, (), Subpath> = DowncastBothNode::new(value_node);
let node = DynAnyNode::new(graphene_std::vector::generator_nodes::BlitSubpath::new(input_node));
if let ProtoNodeInput::Node(node_id) = proto_node.input {
let pre_node = nodes.get(node_id as usize).unwrap();
(pre_node).then(node).into_type_erased()
} else {
node.into_type_erased()
}
})
}),
(NodeIdentifier::new("graphene_std::vector::generator_nodes::TransformSubpathNode", &[]), |proto_node, stack| {
stack.push_fn(move |nodes| {
let ConstructionArgs::Nodes(construction_nodes) = proto_node.construction_args else { unreachable!("TransformSubpathNode without subpath input") };
let translate_node: DowncastBothNode<_, (), DVec2> = DowncastBothNode::new(nodes.get(construction_nodes[0] as usize).unwrap());
let rotate_node: DowncastBothNode<_, (), f64> = DowncastBothNode::new(nodes.get(construction_nodes[1] as usize).unwrap());
let scale_node: DowncastBothNode<_, (), DVec2> = DowncastBothNode::new(nodes.get(construction_nodes[2] as usize).unwrap());
let shear_node: DowncastBothNode<_, (), DVec2> = DowncastBothNode::new(nodes.get(construction_nodes[3] as usize).unwrap());
let node = DynAnyNode::new(graphene_std::vector::generator_nodes::TransformSubpathNode::new(translate_node, rotate_node, scale_node, shear_node));
if let ProtoNodeInput::Node(node_id) = proto_node.input {
let pre_node = nodes.get(node_id as usize).unwrap();
(pre_node).then(node).into_type_erased()
} else {
node.into_type_erased()
}
})
}),
#[cfg(feature = "gpu")]
(
NodeIdentifier::new("graphene_std::executor::MapGpuNode", &[concrete!("&TypeErasedNode"), concrete!("Color"), concrete!("Color")]),
|proto_node, stack| {
if let ConstructionArgs::Nodes(operation_node_id) = proto_node.construction_args {
stack.push_fn(move |nodes| {
info!("Map image Depending upon id {:?}", operation_node_id);
let operation_node = nodes.get(operation_node_id[0] as usize).unwrap();
let input_node: DowncastBothNode<_, (), &graph_craft::document::NodeNetwork> = DowncastBothNode::new(operation_node);
let map_node: graphene_std::executor::MapGpuNode<_, Vec<u32>, u32, u32> = graphene_std::executor::MapGpuNode::new(input_node);
let map_node = DynAnyNode::new(map_node);
if let ProtoNodeInput::Node(node_id) = proto_node.input {
let pre_node = nodes.get(node_id as usize).unwrap();
(pre_node).then(map_node).into_type_erased()
} else {
map_node.into_type_erased()
}
})
} else {
unimplemented!()
}
},
),
#[cfg(feature = "gpu")]
(
NodeIdentifier::new("graphene_std::executor::MapGpuSingleImageNode", &[concrete!("&TypeErasedNode")]),
|proto_node, stack| {
if let ConstructionArgs::Nodes(operation_node_id) = proto_node.construction_args {
stack.push_fn(move |nodes| {
info!("Map image Depending upon id {:?}", operation_node_id);
let operation_node = nodes.get(operation_node_id[0] as usize).unwrap();
let input_node: DowncastBothNode<_, (), String> = DowncastBothNode::new(operation_node);
let map_node = graphene_std::executor::MapGpuSingleImageNode(input_node);
let map_node = DynAnyNode::new(map_node);
if let ProtoNodeInput::Node(node_id) = proto_node.input {
let pre_node = nodes.get(node_id as usize).unwrap();
(pre_node).then(map_node).into_type_erased()
} else {
map_node.into_type_erased()
}
})
} else {
unimplemented!()
}
},
),
#[cfg(feature = "quantization")]
(
NodeIdentifier::new("graphene_std::quantization::GenerateQuantizationNode", &[concrete!("&TypeErasedNode")]),
|proto_node, stack| {
if let ConstructionArgs::Nodes(operation_node_id) = proto_node.construction_args {
stack.push_fn(move |nodes| {
info!("Quantization Depending upon id {:?}", operation_node_id);
let samples_node = nodes.get(operation_node_id[0] as usize).unwrap();
let index_node = nodes.get(operation_node_id[1] as usize).unwrap();
let samples_node: DowncastBothNode<_, (), u32> = DowncastBothNode::new(samples_node);
let index_node: DowncastBothNode<_, (), u32> = DowncastBothNode::new(index_node);
let map_node = graphene_std::quantization::GenerateQuantizationNode::new(samples_node, index_node);
let map_node = DynAnyNode::new(map_node);
if let ProtoNodeInput::Node(node_id) = proto_node.input {
let pre_node = nodes.get(node_id as usize).unwrap();
(pre_node).then(map_node).into_type_erased()
} else {
map_node.into_type_erased()
}
})
} else {
unimplemented!()
}
},
),
<<<<<<< HEAD
*/
];
let mut map: HashMap<NodeIdentifier, HashMap<NodeIOTypes, NodeConstructor>> = HashMap::new();
for (id, c, types) in node_types.into_iter().flatten() {
@ -466,101 +332,7 @@ fn node_registry() -> HashMap<NodeIdentifier, HashMap<NodeIOTypes, NodeConstruct
pub static NODE_REGISTRY: Lazy<HashMap<NodeIdentifier, HashMap<NodeIOTypes, NodeConstructor>>> = Lazy::new(|| node_registry());
/*
#[cfg(test)]
mod protograph_testing {
use borrow_stack::BorrowStack;
use super::*;
#[test]
fn add_values() {
let stack = FixedSizeStack::new(256);
let val_1_protonode = ProtoNode::value(ConstructionArgs::Value(Box::new(2u32)));
constrcut_node(val_1_protonode, &stack);
let val_2_protonode = ProtoNode::value(ConstructionArgs::Value(Box::new(40u32)));
constrcut_node(val_2_protonode, &stack);
let cons_protonode = ProtoNode {
construction_args: ConstructionArgs::Nodes(vec![1]),
input: ProtoNodeInput::Node(0),
identifier: NodeIdentifier::new("graphene_core::structural::ConsNode", &[concrete!("u32"), concrete!("u32")]),
};
constrcut_node(cons_protonode, &stack);
let add_protonode = ProtoNode {
construction_args: ConstructionArgs::Nodes(vec![]),
input: ProtoNodeInput::Node(2),
identifier: NodeIdentifier::new("graphene_core::ops::AddNode", &[concrete!("u32"), concrete!("u32")]),
};
constrcut_node(add_protonode, &stack);
let result = unsafe { stack.get()[3].eval(Box::new(())) };
let val = *dyn_any::downcast::<u32>(result).unwrap();
assert_eq!(val, 42);
// TODO: adde tests testing the node registry
}
#[test]
fn grayscale_color() {
let stack = FixedSizeStack::new(256);
let val_protonode = ProtoNode::value(ConstructionArgs::Value(Box::new(Color::from_rgb8(10, 20, 30))));
constrcut_node(val_protonode, &stack);
let grayscale_protonode = ProtoNode {
construction_args: ConstructionArgs::Nodes(vec![]),
input: ProtoNodeInput::Node(0),
identifier: NodeIdentifier::new("graphene_core::raster::GrayscaleColorNode", &[]),
};
constrcut_node(grayscale_protonode, &stack);
let result = unsafe { stack.get()[1].eval(Box::new(())) };
let val = *dyn_any::downcast::<Color>(result).unwrap();
assert_eq!(val, Color::from_rgb8(20, 20, 20));
}
#[test]
fn load_image() {
let stack = FixedSizeStack::new(256);
let image_protonode = ProtoNode {
construction_args: ConstructionArgs::Nodes(vec![]),
input: ProtoNodeInput::None,
identifier: NodeIdentifier::new("graphene_std::raster::ImageNode", &[concrete!("&str")]),
};
constrcut_node(image_protonode, &stack);
let result = unsafe { stack.get()[0].eval(Box::new("../gstd/test-image-1.png")) };
let image = *dyn_any::downcast::<Image>(result).unwrap();
assert_eq!(image.height, 240);
}
#[test]
fn grayscale_map_image() {
let stack = FixedSizeStack::new(256);
let image_protonode = ProtoNode {
construction_args: ConstructionArgs::Nodes(vec![]),
input: ProtoNodeInput::None,
identifier: NodeIdentifier::new("graphene_std::raster::ImageNode", &[concrete!("&str")]),
};
constrcut_node(image_protonode, &stack);
let grayscale_protonode = ProtoNode {
construction_args: ConstructionArgs::Nodes(vec![]),
input: ProtoNodeInput::None,
identifier: NodeIdentifier::new("graphene_core::raster::GrayscaleColorNode", &[]),
};
constrcut_node(grayscale_protonode, &stack);
let image_map_protonode = ProtoNode {
construction_args: ConstructionArgs::Nodes(vec![1]),
input: ProtoNodeInput::Node(0),
identifier: NodeIdentifier::new("graphene_std::raster::MapImageNode", &[]),
};
constrcut_node(image_map_protonode, &stack);
let result = unsafe { stack.get()[2].eval(Box::new("../gstd/test-image-1.png")) };
let image = *dyn_any::downcast::<Image>(result).unwrap();
assert!(!image.data.iter().any(|c| c.r() != c.b() || c.b() != c.g()));
}
}
*/

View file

@ -17,6 +17,7 @@ pub struct GpuExecutor<'a, I: StaticTypeSized, O> {
impl<'a, I: StaticTypeSized, O> GpuExecutor<'a, I, O> {
pub fn new(context: Context, shader: Cow<'a, [u32]>, entry_point: String) -> anyhow::Result<Self> {
log::info!("Creating executor");
Ok(Self {
context,
entry_point,
@ -28,6 +29,7 @@ impl<'a, I: StaticTypeSized, O> GpuExecutor<'a, I, O> {
impl<'a, I: StaticTypeSized + Sync + Pod + Send, O: StaticTypeSized + Send + Sync + Pod> Executor for GpuExecutor<'a, I, O> {
fn execute<'i, 's: 'i>(&'s self, input: Any<'i>) -> Result<Any<'i>, Box<dyn std::error::Error>> {
log::info!("Executing shader");
let input = dyn_any::downcast::<Vec<I>>(input).expect("Wrong input type");
let context = &self.context;
let future = execute_shader(context.device.clone(), context.queue.clone(), self.shader.to_vec(), *input, self.entry_point.clone());
@ -40,6 +42,11 @@ impl<'a, I: StaticTypeSized + Sync + Pod + Send, O: StaticTypeSized + Send + Syn
async fn execute_shader<I: Pod + Send + Sync, O: Pod + Send + Sync>(device: Arc<wgpu::Device>, queue: Arc<wgpu::Queue>, shader: Vec<u32>, data: Vec<I>, entry_point: String) -> Option<Vec<O>> {
// Loads the shader from WGSL
dbg!(&shader);
//write shader to file
use std::io::Write;
let mut file = std::fs::File::create("/tmp/shader.spv").unwrap();
file.write_all(bytemuck::cast_slice(&shader)).unwrap();
let cs_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
label: None,
source: wgpu::ShaderSource::SpirV(shader.into()),
@ -122,7 +129,7 @@ async fn execute_shader<I: Pod + Send + Sync, O: Pod + Send + Sync>(device: Arc<
cpass.set_pipeline(&compute_pipeline);
cpass.set_bind_group(0, &bind_group, &[]);
cpass.insert_debug_marker("compute node network evaluation");
cpass.dispatch_workgroups(data.len() as u32, 1, 1); // Number of cells to run, the (x,y,z) size of item being processed
cpass.dispatch_workgroups(data.len().min(65535) as u32, 1, 1); // Number of cells to run, the (x,y,z) size of item being processed
}
// Sets adds copy operation to command encoder.
// Will copy data from storage buffer on GPU to staging buffer on CPU.