Update GPU execution and quantization to new node system (#1070)

* Update GPU and quantization to new node system

Squashed commit of the following:

commit 3b69bdafed79f0bb1279609537a8eeead3f06830
Author: Dennis Kobert <dennis@kobert.dev>
Date:   Sun Mar 5 11:37:17 2023 +0100

    Disable dev tools by default

commit dbbbedd68e48d1162442574ad8877c9922d40e4a
Merge: b1018eb5 a8f6e11e
Author: Dennis Kobert <dennis@kobert.dev>
Date:   Sun Mar 5 10:45:00 2023 +0100

    Merge branch 'vite' into tauri-restructure-lite

commit b1018eb5ee56c2d23f9d5a4f034608ec684bd746
Merge: 3195833e 0512cb24
Author: Dennis Kobert <dennis@kobert.dev>
Date:   Fri Mar 3 17:06:21 2023 +0100

    Merge branch 'master' into tauri-restructure-lite

commit 3195833e4088a4ed7984955c72617b27b7e39bfc
Author: Dennis Kobert <dennis@kobert.dev>
Date:   Fri Mar 3 17:06:02 2023 +0100

    Bump number of samples

commit 3e57e1e3280759cf4f75726635e31d2b8e9387f9
Author: Dennis Kobert <dennis@kobert.dev>
Date:   Fri Mar 3 16:55:52 2023 +0100

    Move part of quantization code to gcore

commit 10c15b0bc6ffb51e2bf2d94cd4eb0e24d761fb6f
Merge: 2b3db45a 8fe8896c
Author: Dennis Kobert <dennis@kobert.dev>
Date:   Fri Mar 3 14:28:56 2023 +0100

    Merge remote-tracking branch 'origin/master' into tauri-restructure-lite

commit 2b3db45aee44a20660f0b1204666bb81e5a7e4b6
Author: Dennis Kobert <dennis@kobert.dev>
Date:   Fri Mar 3 14:17:11 2023 +0100

    Fix types in node registry

commit 9122f35c0ba9a86255709680d744a48d3c7dcac4
Merge: 26eefc43 2cf4ee0f
Author: Dennis Kobert <dennis@kobert.dev>
Date:   Fri Mar 3 01:04:55 2023 +0100

    Merge remote-tracking branch 'origin/master' into tauri-restructure-lite

commit 26eefc437eaad873f8d38fdb1fae0a1e3ec189e4
Author: Dennis Kobert <dennis@kobert.dev>
Date:   Thu Mar 2 23:05:53 2023 +0100

    Add Quantize node to document_node_types

commit 3f7606a91329200b2c025010d4a0cffee840a11c
Author: Dennis Kobert <dennis@kobert.dev>
Date:   Thu Mar 2 17:47:51 2023 +0100

    Add quantization nodes to node registry

commit 22d8e477ef79eef5b57b1dc9805e41bbf81cae43
Author: Dennis Kobert <dennis@kobert.dev>
Date:   Thu Mar 2 17:13:28 2023 +0100

    Introduce scopes (#1053)

    * Implement let binding

    * Add lambda inputs

    * Fix tests

    * Fix proto network formatting

    * Generate a template Scoped network by default

    * Add comment to explain the lambda parameter

    * Move binding wrapping out of the template

    * Fix errors cause by image frames

commit 9e0c29d92a164d4a4063e93480e1e289ef5243fe
Author: Alexandru Ică <alexandru@seyhanlee.com>
Date:   Thu Mar 2 15:55:10 2023 +0200

    Make use of ImageFrame in the node system more extensively (#1055) (#1062)

    Make the node system use ImageFrame more extensively (#1055)

commit 5912ef9a1a807917eeb90c1f4835bd8a5de9c821
Author: Dennis Kobert <dennis@kobert.dev>
Date:   Wed Mar 1 16:15:21 2023 +0100

    Split quantization into multiple nodes

commit 285d7b76c176b3e2679ea24eecb38ef867a79f3b
Author: Dennis Kobert <dennis@kobert.dev>
Date:   Mon Feb 27 12:35:57 2023 +0100

    Fix gpu support

commit e0b6327eebba8caf7545c4fedc6670abc4c3652e
Author: Dennis Kobert <dennis@kobert.dev>
Date:   Thu Feb 16 22:08:53 2023 +0100

    Don't watch frontend files when using tauri

commit 58ae146f6da935cfd37afbd25e1c331b615252da
Author: Dennis Kobert <dennis@kobert.dev>
Date:   Thu Feb 16 21:48:54 2023 +0100

    Migrate vue code base to vite

commit f996390cc312618a60f98ccb9cd515f1bae5006d
Author: Dennis Kobert <dennis@kobert.dev>
Date:   Thu Feb 16 19:34:33 2023 +0100

    Start migrating vue to use vite

commit 29d752f47cfd1c74ee51fac6f3d75557a378471c
Author: Dennis Kobert <dennis@kobert.dev>
Date:   Thu Feb 16 19:00:53 2023 +0100

    Kill cargo watch process automatically

commit 4d1c76b07acadbf609dbab7d57d9a7769b81d4b5
Author: Dennis Kobert <dennis@kobert.dev>
Date:   Thu Feb 16 17:37:27 2023 +0100

    Start playing around with vite infrastructure

commit 8494f5e9227aa433fd5ca75b268a6a96b2706b36
Author: Locria Cyber <74560659+locriacyber@users.noreply.github.com>
Date:   Thu Jan 19 18:40:46 2023 +0000

    Fix import style and eslint rules

commit 92490f7774a7351bb40091bcec78f79c28704768
Author: Locria Cyber <74560659+locriacyber@users.noreply.github.com>
Date:   Thu Jan 19 18:25:09 2023 +0000

    Fix icons

commit dc67821abad87f8ff780b12ae96668af2f7bb355
Author: Locria Cyber <74560659+locriacyber@users.noreply.github.com>
Date:   Thu Jan 19 18:20:48 2023 +0000

    Add license generator with rollup

commit 441e339d31b76dac4f91321d39a39900b5a79bc1
Author: Locria Cyber <74560659+locriacyber@users.noreply.github.com>
Date:   Thu Jan 19 18:14:22 2023 +0000

    Use eslint --fix to fix TS-in-svelte type imports. Now it compiles.

commit 2e847d339e7dcd51ed4c4677ed337c1e20636724
Author: Locria Cyber <74560659+locriacyber@users.noreply.github.com>
Date:   Thu Jan 19 17:31:49 2023 +0000

    Remove webpack and plugins

commit 3adab1b7f40ff17b91163e7ca47a403ef3c02fbc
Author: Dennis Kobert <dennis@kobert.dev>
Date:   Thu Mar 2 16:10:19 2023 +0100

    Fix errors cause by image frames

commit 4e5f838995e213b4696225a473b9c56c0084e7a8
Author: Alexandru Ică <alexandru@seyhanlee.com>
Date:   Thu Mar 2 15:55:10 2023 +0200

    Make use of ImageFrame in the node system more extensively (#1055) (#1062)

    Make the node system use ImageFrame more extensively (#1055)

commit 1d4b0e29c693a53c068f1a30f0e857a9c1a59587
Author: Dennis Kobert <dennis@kobert.dev>
Date:   Wed Mar 1 15:13:51 2023 +0100

    Update node graph guide readme with new syntax (#1061)

commit 6735d8c61f5709e22d2b22abd037bab417e868d6
Author: Rob Nadal <Robnadal44@gmail.com>
Date:   Tue Feb 28 18:59:06 2023 -0500

    Bezier-rs: Add function to smoothly join bezier curves (#1037)

    * Added bezier join

    * Stylistic changes per review

commit cd1d7aa7fbcce39fbbf7762d131ee16ad9cb46dd
Author: Dennis Kobert <dennis@kobert.dev>
Date:   Wed Feb 22 23:42:32 2023 +0100

    Implement let binding

    Add lambda inputs

    Fix tests

    Fix proto network formatting

    Generate a template Scoped network by default

    Add comment to explain the lambda parameter

    Move binding wrapping out of the template

* Update package-lock.json

* Regenerate package-lock.json and fix lint errors

* Readd git keep dir

* Revert change to panic.ts

* Fix clippy warnings

* Apply code review

* Clean up node_registry

* Fix test / spriv -> spirv typos
This commit is contained in:
Dennis Kobert 2023-03-05 13:22:14 +01:00 committed by Keavon Chambers
parent 4ea3802df1
commit b55e233fff
38 changed files with 2282 additions and 1924 deletions

820
Cargo.lock generated

File diff suppressed because it is too large Load diff

View file

@ -44,5 +44,8 @@ opt-level = 3
[profile.dev.package.graphite-wasm-svelte] [profile.dev.package.graphite-wasm-svelte]
opt-level = 3 opt-level = 3
[profile.dev.package.autoquant]
opt-level = 3
#[profile.dev] #[profile.dev]
#opt-level = 3 #opt-level = 3

View file

@ -449,9 +449,9 @@ fn static_nodes() -> Vec<DocumentNodeType> {
DocumentNodeType { DocumentNodeType {
name: "GpuImage", name: "GpuImage",
category: "Image Adjustments", category: "Image Adjustments",
identifier: NodeImplementation::proto("graphene_std::executor::MapGpuSingleImageNode"), identifier: NodeImplementation::proto("graphene_std::executor::MapGpuSingleImageNode<_>"),
inputs: vec![ inputs: vec![
DocumentInputType::new("Image", TaggedValue::ImageFrame(ImageFrame::empty()), true), DocumentInputType::value("Image", TaggedValue::ImageFrame(ImageFrame::empty()), true),
DocumentInputType { DocumentInputType {
name: "Path", name: "Path",
data_type: FrontendGraphDataType::Text, data_type: FrontendGraphDataType::Text,
@ -463,9 +463,9 @@ fn static_nodes() -> Vec<DocumentNodeType> {
}, },
#[cfg(feature = "quantization")] #[cfg(feature = "quantization")]
DocumentNodeType { DocumentNodeType {
name: "QuantizeImage", name: "Generate Quantization",
category: "Image Adjustments", category: "Quantization",
identifier: NodeImplementation::proto("graphene_std::quantization::GenerateQuantizationNode"), identifier: NodeImplementation::proto("graphene_std::quantization::GenerateQuantizationNode<_, _>"),
inputs: vec![ inputs: vec![
DocumentInputType { DocumentInputType {
name: "Image", name: "Image",
@ -483,7 +483,47 @@ fn static_nodes() -> Vec<DocumentNodeType> {
default: NodeInput::value(TaggedValue::U32(0), false), default: NodeInput::value(TaggedValue::U32(0), false),
}, },
], ],
outputs: vec![DocumentOutputType::new("Image", FrontendGraphDataType::Raster)], outputs: vec![DocumentOutputType::new("Quantization", FrontendGraphDataType::General)],
properties: node_properties::quantize_properties,
},
#[cfg(feature = "quantization")]
DocumentNodeType {
name: "Quantize Image",
category: "Quantization",
identifier: NodeImplementation::proto("graphene_core::quantization::QuantizeNode<_>"),
inputs: vec![
DocumentInputType {
name: "Image",
data_type: FrontendGraphDataType::Raster,
default: NodeInput::value(TaggedValue::ImageFrame(ImageFrame::empty()), true),
},
DocumentInputType {
name: "Quantization",
data_type: FrontendGraphDataType::General,
default: NodeInput::value(TaggedValue::Quantization(core::array::from_fn(|_| Default::default())), true),
},
],
outputs: vec![DocumentOutputType::new("Encoded", FrontendGraphDataType::Raster)],
properties: node_properties::quantize_properties,
},
#[cfg(feature = "quantization")]
DocumentNodeType {
name: "DeQuantize Image",
category: "Quantization",
identifier: NodeImplementation::proto("graphene_core::quantization::DeQuantizeNode<_>"),
inputs: vec![
DocumentInputType {
name: "Encoded",
data_type: FrontendGraphDataType::Raster,
default: NodeInput::value(TaggedValue::ImageFrame(ImageFrame::empty()), true),
},
DocumentInputType {
name: "Quantization",
data_type: FrontendGraphDataType::General,
default: NodeInput::value(TaggedValue::Quantization(core::array::from_fn(|_| Default::default())), true),
},
],
outputs: vec![DocumentOutputType::new("Decoded", FrontendGraphDataType::Raster)],
properties: node_properties::quantize_properties, properties: node_properties::quantize_properties,
}, },
DocumentNodeType { DocumentNodeType {

View file

@ -71,23 +71,24 @@ fn start_widgets(document_node: &DocumentNode, node_id: NodeId, index: usize, na
widgets widgets
} }
// fn text_widget(document_node: &DocumentNode, node_id: NodeId, index: usize, name: &str, blank_assist: bool) -> Vec<WidgetHolder> { #[cfg(feature = "gpu")]
// let mut widgets = start_widgets(document_node, node_id, index, name, FrontendGraphDataType::Text, blank_assist); fn text_widget(document_node: &DocumentNode, node_id: NodeId, index: usize, name: &str, blank_assist: bool) -> Vec<WidgetHolder> {
let mut widgets = start_widgets(document_node, node_id, index, name, FrontendGraphDataType::Text, blank_assist);
// if let NodeInput::Value { if let NodeInput::Value {
// tagged_value: TaggedValue::String(x), tagged_value: TaggedValue::String(x),
// exposed: false, exposed: false,
// } = &document_node.inputs[index] } = &document_node.inputs[index]
// { {
// widgets.extend_from_slice(&[ widgets.extend_from_slice(&[
// WidgetHolder::unrelated_separator(), WidgetHolder::unrelated_separator(),
// TextInput::new(x.clone()) TextInput::new(x.clone())
// .on_update(update_value(|x: &TextInput| TaggedValue::String(x.value.clone()), node_id, index)) .on_update(update_value(|x: &TextInput| TaggedValue::String(x.value.clone()), node_id, index))
// .widget_holder(), .widget_holder(),
// ]) ])
// } }
// widgets widgets
// } }
fn text_area_widget(document_node: &DocumentNode, node_id: NodeId, index: usize, name: &str, blank_assist: bool) -> Vec<WidgetHolder> { fn text_area_widget(document_node: &DocumentNode, node_id: NodeId, index: usize, name: &str, blank_assist: bool) -> Vec<WidgetHolder> {
let mut widgets = start_widgets(document_node, node_id, index, name, FrontendGraphDataType::Text, blank_assist); let mut widgets = start_widgets(document_node, node_id, index, name, FrontendGraphDataType::Text, blank_assist);

File diff suppressed because it is too large Load diff

View file

@ -9,8 +9,9 @@
"build": "vue-cli-service build || (npm run print-building-help && exit 1)", "build": "vue-cli-service build || (npm run print-building-help && exit 1)",
"lint": "vue-cli-service lint || (npm run print-linting-help && exit 1)", "lint": "vue-cli-service lint || (npm run print-linting-help && exit 1)",
"lint-no-fix": "vue-cli-service lint --no-fix || (npm run print-linting-help && exit 1)", "lint-no-fix": "vue-cli-service lint --no-fix || (npm run print-linting-help && exit 1)",
"tauri:build": "vue-cli-service tauri:build", "tauri:build": "npm run tauri:build-wasm && vue-cli-service tauri:build",
"tauri:serve": "vue-cli-service tauri:serve", "tauri:build-wasm": "wasm-pack build wasm --release -- --features tauri",
"tauri:serve": "echo 'Make sure you build the wasm binary for tauri using `npm run tauri:build-wasm`' && npm run serve",
"print-building-help": "echo 'Graphite project failed to build. Did you remember to `npm install` the dependencies in `/frontend`?'", "print-building-help": "echo 'Graphite project failed to build. Did you remember to `npm install` the dependencies in `/frontend`?'",
"print-linting-help": "echo 'Graphite project had lint errors, or may have otherwise failed. In the latter case, did you remember to `npm install` the dependencies in `/frontend`?'" "print-linting-help": "echo 'Graphite project had lint errors, or may have otherwise failed. In the latter case, did you remember to `npm install` the dependencies in `/frontend`?'"
}, },

View file

@ -31,6 +31,7 @@ futures = "0.3.25"
[features] [features]
gpu = ["graphite-editor/gpu"] gpu = ["graphite-editor/gpu"]
quantization = ["graphite-editor/quantization"]
# by default Tauri runs in production mode # by default Tauri runs in production mode
# when `tauri dev` runs it is executed with `cargo run --no-default-features` if `devPath` is an URL # when `tauri dev` runs it is executed with `cargo run --no-default-features` if `devPath` is an URL
default = [ "custom-protocol" ] default = [ "custom-protocol" ]

View file

@ -15,7 +15,6 @@ use http::{Response, StatusCode};
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use std::sync::Mutex; use std::sync::Mutex;
use tauri::Manager;
static IMAGES: Mutex<Option<HashMap<String, FrontendImageData>>> = Mutex::new(None); static IMAGES: Mutex<Option<HashMap<String, FrontendImageData>>> = Mutex::new(None);
static EDITOR: Mutex<Option<Editor>> = Mutex::new(None); static EDITOR: Mutex<Option<Editor>> = Mutex::new(None);
@ -66,8 +65,9 @@ async fn main() {
tauri::Builder::default() tauri::Builder::default()
.invoke_handler(tauri::generate_handler![set_random_seed, handle_message]) .invoke_handler(tauri::generate_handler![set_random_seed, handle_message])
.setup(|app| { .setup(|_app| {
app.get_window("main").unwrap().open_devtools(); //use tauri::Manager;
//_app.get_window("main").unwrap().open_devtools();
Ok(()) Ok(())
}) })
.run(tauri::generate_context!()) .run(tauri::generate_context!())

View file

@ -117,8 +117,7 @@
import { defineComponent, type PropType } from "vue"; import { defineComponent, type PropType } from "vue";
import { debouncer } from "@/components/widgets/debounce"; import { debouncer } from "@/components/widgets/debounce";
import type { Widget } from "@/wasm-communication/messages"; import { isWidgetColumn, isWidgetRow, type WidgetColumn, type WidgetRow, type Widget } from "@/wasm-communication/messages";
import { isWidgetColumn, isWidgetRow, type WidgetColumn, type WidgetRow } from "@/wasm-communication/messages";
import PivotAssist from "@/components/widgets/assists/PivotAssist.vue"; import PivotAssist from "@/components/widgets/assists/PivotAssist.vue";
import BreadcrumbTrailButtons from "@/components/widgets/buttons/BreadcrumbTrailButtons.vue"; import BreadcrumbTrailButtons from "@/components/widgets/buttons/BreadcrumbTrailButtons.vue";

View file

@ -4,8 +4,7 @@ import { type IconName } from "@/utility-functions/icons";
import { browserVersion, operatingSystem } from "@/utility-functions/platform"; import { browserVersion, operatingSystem } from "@/utility-functions/platform";
import { stripIndents } from "@/utility-functions/strip-indents"; import { stripIndents } from "@/utility-functions/strip-indents";
import { type Editor } from "@/wasm-communication/editor"; import { type Editor } from "@/wasm-communication/editor";
import type { TextLabel } from "@/wasm-communication/messages"; import { type TextButtonWidget, type TextLabel, type WidgetLayout, Widget, DisplayDialogPanic } from "@/wasm-communication/messages";
import { type TextButtonWidget, type WidgetLayout, Widget, DisplayDialogPanic } from "@/wasm-communication/messages";
export function createPanicManager(editor: Editor, dialogState: DialogState): void { export function createPanicManager(editor: Editor, dialogState: DialogState): void {
// Code panic dialog and console error // Code panic dialog and console error

View file

@ -4,8 +4,7 @@
import { blobToBase64 } from "@/utility-functions/files"; import { blobToBase64 } from "@/utility-functions/files";
import { type RequestResult, requestWithUploadDownloadProgress } from "@/utility-functions/network"; import { type RequestResult, requestWithUploadDownloadProgress } from "@/utility-functions/network";
import { type Editor } from "@/wasm-communication/editor"; import { type Editor } from "@/wasm-communication/editor";
import type { XY } from "@/wasm-communication/messages"; import { type ImaginateGenerationParameters, type XY } from "@/wasm-communication/messages";
import { type ImaginateGenerationParameters } from "@/wasm-communication/messages";
const MAX_POLLING_RETRIES = 4; const MAX_POLLING_RETRIES = 4;
const SERVER_STATUS_CHECK_TIMEOUT = 5000; const SERVER_STATUS_CHECK_TIMEOUT = 5000;

View file

@ -496,7 +496,7 @@ const mouseCursorIconCSSNames = {
Rotate: "custom-rotate", Rotate: "custom-rotate",
} as const; } as const;
export type MouseCursor = keyof typeof mouseCursorIconCSSNames; export type MouseCursor = keyof typeof mouseCursorIconCSSNames;
export type MouseCursorIcon = typeof mouseCursorIconCSSNames[MouseCursor]; export type MouseCursorIcon = (typeof mouseCursorIconCSSNames)[MouseCursor];
export class UpdateMouseCursor extends JsMessage { export class UpdateMouseCursor extends JsMessage {
@Transform(({ value }: { value: MouseCursor }) => mouseCursorIconCSSNames[value] || "alias") @Transform(({ value }: { value: MouseCursor }) => mouseCursorIconCSSNames[value] || "alias")
@ -1169,7 +1169,7 @@ const widgetSubTypes = [
{ value: TextLabel, name: "TextLabel" }, { value: TextLabel, name: "TextLabel" },
{ value: PivotAssist, name: "PivotAssist" }, { value: PivotAssist, name: "PivotAssist" },
]; ];
export type WidgetPropsSet = InstanceType<typeof widgetSubTypes[number]["value"]>; export type WidgetPropsSet = InstanceType<(typeof widgetSubTypes)[number]["value"]>;
export class Widget { export class Widget {
constructor(props: WidgetPropsSet, widgetId: bigint) { constructor(props: WidgetPropsSet, widgetId: bigint) {

View file

@ -14,7 +14,7 @@ documentation = "https://docs.rs/dyn-any"
[dependencies] [dependencies]
dyn-any-derive = { path = "derive", version = "0.2.0", optional = true } dyn-any-derive = { path = "derive", version = "0.2.0", optional = true }
log = { version = "0.4", optional = true } log = { version = "0.4", optional = true }
glam = { version = "0.22", optional = true } glam = { version = "0.22", optional = true, default-features = false }
[features] [features]
derive = ["dyn-any-derive"] derive = ["dyn-any-derive"]
@ -25,7 +25,7 @@ rc = []
glam = ["dep:glam"] glam = ["dep:glam"]
alloc = [] alloc = []
large-atomics = [] large-atomics = []
std = ["alloc", "rc"] std = ["alloc", "rc", "glam/default"]
default = ["std", "large-atomics"] default = ["std", "large-atomics"]
[package.metadata.docs.rs] [package.metadata.docs.rs]

View file

@ -5,7 +5,7 @@ pub async fn compile<I, O>(network: NodeNetwork) -> Result<Vec<u8>, reqwest::Err
let client = reqwest::Client::new(); let client = reqwest::Client::new();
let compile_request = CompileRequest::new(network, std::any::type_name::<I>().to_owned(), std::any::type_name::<O>().to_owned()); let compile_request = CompileRequest::new(network, std::any::type_name::<I>().to_owned(), std::any::type_name::<O>().to_owned());
let response = client.post("http://localhost:3000/compile/spriv").json(&compile_request).send(); let response = client.post("http://localhost:3000/compile/spirv").json(&compile_request).send();
let response = response.await?; let response = response.await?;
response.bytes().await.map(|b| b.to_vec()) response.bytes().await.map(|b| b.to_vec())
} }

View file

@ -1,6 +1,5 @@
use gpu_compiler_bin_wrapper::CompileRequest; use gpu_compiler_bin_wrapper::CompileRequest;
use graph_craft::concrete; use graph_craft::concrete;
use graph_craft::document::value::TaggedValue;
use graph_craft::document::*; use graph_craft::document::*;
use graph_craft::*; use graph_craft::*;
@ -18,13 +17,7 @@ fn main() {
0, 0,
DocumentNode { DocumentNode {
name: "Inc Node".into(), name: "Inc Node".into(),
inputs: vec![ inputs: vec![NodeInput::Network(concrete!(u32))],
NodeInput::Network(concrete!(u32)),
NodeInput::Value {
tagged_value: TaggedValue::U32(1),
exposed: false,
},
],
implementation: DocumentNodeImplementation::Network(add_network()), implementation: DocumentNodeImplementation::Network(add_network()),
metadata: DocumentNodeMetadata::default(), metadata: DocumentNodeMetadata::default(),
}, },
@ -34,13 +27,13 @@ fn main() {
}; };
let compile_request = CompileRequest::new(network, "u32".to_owned(), "u32".to_owned()); let compile_request = CompileRequest::new(network, "u32".to_owned(), "u32".to_owned());
let response = client.post("http://localhost:3000/compile/spriv").json(&compile_request).send().unwrap(); let response = client.post("http://localhost:3000/compile/spirv").json(&compile_request).send().unwrap();
println!("response: {:?}", response); println!("response: {:?}", response);
} }
fn add_network() -> NodeNetwork { fn add_network() -> NodeNetwork {
NodeNetwork { NodeNetwork {
inputs: vec![0, 0], inputs: vec![0],
outputs: vec![NodeOutput::new(1, 0)], outputs: vec![NodeOutput::new(1, 0)],
disabled: vec![], disabled: vec![],
previous_outputs: None, previous_outputs: None,
@ -48,10 +41,10 @@ fn add_network() -> NodeNetwork {
( (
0, 0,
DocumentNode { DocumentNode {
name: "Cons".into(), name: "Dup".into(),
inputs: vec![NodeInput::Network(concrete!(u32)), NodeInput::Network(concrete!(u32))], inputs: vec![NodeInput::Network(concrete!(u32))],
metadata: DocumentNodeMetadata::default(), metadata: DocumentNodeMetadata::default(),
implementation: DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_core::structural::ConsNode")), implementation: DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_core::ops::DupNode")),
}, },
), ),
( (

View file

@ -23,14 +23,14 @@ async fn main() {
let app = Router::new() let app = Router::new()
.route("/", get(|| async { "Hello from compilation server!" })) .route("/", get(|| async { "Hello from compilation server!" }))
.route("/compile", get(|| async { "Supported targets: spirv" })) .route("/compile", get(|| async { "Supported targets: spirv" }))
.route("/compile/spriv", post(post_compile_spriv)) .route("/compile/spirv", post(post_compile_spirv))
.with_state(shared_state); .with_state(shared_state);
// run it with hyper on localhost:3000 // run it with hyper on localhost:3000
axum::Server::bind(&"0.0.0.0:3000".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); axum::Server::bind(&"0.0.0.0:3000".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
} }
async fn post_compile_spriv(State(state): State<Arc<AppState>>, Json(compile_request): Json<CompileRequest>) -> Result<Vec<u8>, StatusCode> { async fn post_compile_spirv(State(state): State<Arc<AppState>>, Json(compile_request): Json<CompileRequest>) -> Result<Vec<u8>, StatusCode> {
let path = std::env::var("CARGO_MANIFEST_DIR").unwrap() + "/../gpu-compiler/Cargo.toml"; let path = std::env::var("CARGO_MANIFEST_DIR").unwrap() + "/../gpu-compiler/Cargo.toml";
compile_request.compile(state.compile_dir.path().to_str().expect("non utf8 tempdir path"), &path).map_err(|e| { compile_request.compile(state.compile_dir.path().to_str().expect("non utf8 tempdir path"), &path).map_err(|e| {
eprintln!("compilation failed: {}", e); eprintln!("compilation failed: {}", e);

View file

@ -9,20 +9,20 @@ license = "MIT OR Apache-2.0"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[features] [features]
std = ["dyn-any", "dyn-any/std", "alloc"] std = ["dyn-any", "dyn-any/std", "alloc", "glam/std", "specta"]
default = ["async", "serde", "kurbo", "log", "std"] default = ["async", "serde", "kurbo", "log", "std"]
log = ["dep:log"] log = ["dep:log"]
serde = ["dep:serde", "glam/serde"] serde = ["dep:serde", "glam/serde"]
gpu = ["spirv-std", "bytemuck", "glam/bytemuck", "dyn-any"] gpu = ["spirv-std", "bytemuck", "glam/bytemuck", "dyn-any", "glam/libm"]
async = ["async-trait", "alloc"] async = ["async-trait", "alloc"]
nightly = [] nightly = []
alloc = ["dyn-any", "bezier-rs"] alloc = ["dyn-any", "bezier-rs", "once_cell"]
type_id_logging = [] type_id_logging = []
[dependencies] [dependencies]
dyn-any = {path = "../../libraries/dyn-any", features = ["derive", "glam"], optional = true, default-features = false } dyn-any = {path = "../../libraries/dyn-any", features = ["derive", "glam"], optional = true, default-features = false }
spirv-std = { git = "https://github.com/EmbarkStudios/rust-gpu", features = ["glam"] , optional = true} spirv-std = { version = "0.5", features = ["glam"] , optional = true}
bytemuck = {version = "1.8", features = ["derive"], optional = true} bytemuck = {version = "1.8", features = ["derive"], optional = true}
async-trait = {version = "0.1", optional = true} async-trait = {version = "0.1", optional = true}
serde = {version = "1.0", features = ["derive"], optional = true, default-features = false } serde = {version = "1.0", features = ["derive"], optional = true, default-features = false }
@ -32,10 +32,11 @@ bezier-rs = { path = "../../libraries/bezier-rs", optional = true }
kurbo = { git = "https://github.com/linebender/kurbo.git", features = [ kurbo = { git = "https://github.com/linebender/kurbo.git", features = [
"serde", "serde",
], optional = true } ], optional = true }
glam = { version = "^0.22", default-features = false, features = ["scalar-math", "libm"]}
rand_chacha = "0.3.1" rand_chacha = "0.3.1"
spin = "0.9.2" spin = "0.9.2"
glam = { version = "^0.22", default-features = false, features = ["scalar-math"]}
node-macro = {path = "../node-macro"} node-macro = {path = "../node-macro"}
specta.workspace = true specta.workspace = true
once_cell = { version = "1.17.0", default-features = false } specta.optional = true
once_cell = { version = "1.17.0", default-features = false, optional = true }
# forma = { version = "0.1.0", package = "forma-render" } # forma = { version = "0.1.0", package = "forma-render" }

View file

@ -11,6 +11,7 @@ pub mod consts;
pub mod generic; pub mod generic;
pub mod ops; pub mod ops;
pub mod structural; pub mod structural;
#[cfg(feature = "std")]
pub mod uuid; pub mod uuid;
pub mod value; pub mod value;
@ -22,6 +23,8 @@ pub mod raster;
#[cfg(feature = "alloc")] #[cfg(feature = "alloc")]
pub mod vector; pub mod vector;
pub mod quantization;
use core::any::TypeId; use core::any::TypeId;
pub use raster::Color; pub use raster::Color;
@ -33,6 +36,7 @@ pub trait Node<'i, Input: 'i>: 'i {
#[cfg(feature = "alloc")] #[cfg(feature = "alloc")]
mod types; mod types;
#[cfg(feature = "alloc")]
pub use types::*; pub use types::*;
pub trait NodeIO<'i, Input: 'i>: 'i + Node<'i, Input> pub trait NodeIO<'i, Input: 'i>: 'i + Node<'i, Input>

View file

@ -0,0 +1,83 @@
use crate::raster::Color;
use crate::Node;
use dyn_any::{DynAny, StaticType};
#[derive(Clone, Debug, DynAny, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Quantization {
pub fn_index: usize,
pub a: f32,
pub b: f32,
pub c: f32,
pub d: f32,
}
impl core::hash::Hash for Quantization {
fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
self.fn_index.hash(state);
self.a.to_bits().hash(state);
self.b.to_bits().hash(state);
self.c.to_bits().hash(state);
self.d.to_bits().hash(state);
}
}
impl Default for Quantization {
fn default() -> Self {
Self {
fn_index: Default::default(),
a: 1.,
b: Default::default(),
c: Default::default(),
d: Default::default(),
}
}
}
pub type QuantizationChannels = [Quantization; 4];
fn quantize(value: f32, quantization: &Quantization) -> f32 {
let Quantization { fn_index, a, b, c, d } = quantization;
match fn_index {
1 => ((value + a) * d).abs().ln() * b + c,
_ => a * value + b,
}
}
fn decode(value: f32, quantization: &Quantization) -> f32 {
let Quantization { fn_index, a, b, c, d } = quantization;
match fn_index {
1 => -(-c / b).exp() * (a * d * (c / b).exp() - (value / b).exp()) / d,
_ => (value - b) / a,
}
}
pub struct QuantizeNode<Quantization> {
quantization: Quantization,
}
#[node_macro::node_fn(QuantizeNode)]
fn quantize_fn<'a>(color: Color, quantization: [Quantization; 4]) -> Color {
let quant = quantization.as_slice();
let r = quantize(color.r(), &quant[0]);
let g = quantize(color.g(), &quant[1]);
let b = quantize(color.b(), &quant[2]);
let a = quantize(color.a(), &quant[3]);
Color::from_rgbaf32_unchecked(r, g, b, a)
}
pub struct DeQuantizeNode<Quantization> {
quantization: Quantization,
}
#[node_macro::node_fn(DeQuantizeNode)]
fn dequantize_fn<'a>(color: Color, quantization: [Quantization; 4]) -> Color {
let quant = quantization.as_slice();
let r = decode(color.r(), &quant[0]);
let g = decode(color.g(), &quant[1]);
let b = decode(color.b(), &quant[2]);
let a = decode(color.a(), &quant[3]);
Color::from_rgbaf32_unchecked(r, g, b, a)
}

View file

@ -2,6 +2,9 @@ use core::{fmt::Debug, marker::PhantomData};
use crate::Node; use crate::Node;
#[cfg(target_arch = "spirv")]
use spirv_std::num_traits::float::Float;
pub mod color; pub mod color;
pub use self::color::Color; pub use self::color::Color;
@ -180,6 +183,7 @@ impl<'a> ImageWindowIterator<'a> {
} }
} }
#[cfg(not(target_arch = "spirv"))]
impl<'a> Iterator for ImageWindowIterator<'a> { impl<'a> Iterator for ImageWindowIterator<'a> {
type Item = (Color, (i32, i32)); type Item = (Color, (i32, i32));
#[inline] #[inline]
@ -194,6 +198,9 @@ impl<'a> Iterator for ImageWindowIterator<'a> {
if self.y > max_y { if self.y > max_y {
return None; return None;
} }
#[cfg(target_arch = "spirv")]
let value = None;
#[cfg(not(target_arch = "spirv"))]
let value = Some((self.image.data[(self.x + self.y * self.image.width) as usize], (self.x as i32 - start_x, self.y as i32 - start_y))); let value = Some((self.image.data[(self.x + self.y * self.image.width) as usize], (self.x as i32 - start_x, self.y as i32 - start_y)));
self.x += 1; self.x += 1;
@ -245,21 +252,49 @@ where
input.for_each(|x| map_node.eval(x)); input.for_each(|x| map_node.eval(x));
} }
#[cfg(target_arch = "spirv")]
const NOTHING: () = ();
use dyn_any::{DynAny, StaticType}; use dyn_any::{DynAny, StaticType};
#[derive(Clone, Debug, PartialEq, DynAny, Default, Copy)] #[derive(Clone, Debug, PartialEq, DynAny, Copy)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))] #[cfg_attr(feature = "serde", derive(serde::Serialize))]
pub struct ImageSlice<'a> { pub struct ImageSlice<'a> {
pub width: u32, pub width: u32,
pub height: u32, pub height: u32,
#[cfg(not(target_arch = "spirv"))]
pub data: &'a [Color], pub data: &'a [Color],
#[cfg(target_arch = "spirv")]
pub data: &'a (),
}
#[allow(clippy::derivable_impls)]
impl<'a> Default for ImageSlice<'a> {
#[cfg(not(target_arch = "spirv"))]
fn default() -> Self {
Self {
width: Default::default(),
height: Default::default(),
data: Default::default(),
}
}
#[cfg(target_arch = "spirv")]
fn default() -> Self {
Self {
width: Default::default(),
height: Default::default(),
data: &NOTHING,
}
}
} }
impl ImageSlice<'_> { impl ImageSlice<'_> {
#[cfg(not(target_arch = "spirv"))]
pub const fn empty() -> Self { pub const fn empty() -> Self {
Self { width: 0, height: 0, data: &[] } Self { width: 0, height: 0, data: &[] }
} }
} }
#[cfg(not(target_arch = "spirv"))]
impl<'a> IntoIterator for ImageSlice<'a> { impl<'a> IntoIterator for ImageSlice<'a> {
type Item = &'a Color; type Item = &'a Color;
type IntoIter = core::slice::Iter<'a, Color>; type IntoIter = core::slice::Iter<'a, Color>;
@ -268,6 +303,7 @@ impl<'a> IntoIterator for ImageSlice<'a> {
} }
} }
#[cfg(not(target_arch = "spirv"))]
impl<'a> IntoIterator for &'a ImageSlice<'a> { impl<'a> IntoIterator for &'a ImageSlice<'a> {
type Item = &'a Color; type Item = &'a Color;
type IntoIter = core::slice::Iter<'a, Color>; type IntoIter = core::slice::Iter<'a, Color>;

View file

@ -4,8 +4,12 @@ use crate::Node;
use core::fmt::Debug; use core::fmt::Debug;
use dyn_any::{DynAny, StaticType}; use dyn_any::{DynAny, StaticType};
#[cfg(target_arch = "spirv")]
use spirv_std::num_traits::float::Float;
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Default, Clone, Copy, Eq, PartialEq, DynAny, specta::Type, Hash)] #[cfg_attr(feature = "std", derive(specta::Type))]
#[derive(Debug, Default, Clone, Copy, Eq, PartialEq, DynAny, Hash)]
pub enum LuminanceCalculation { pub enum LuminanceCalculation {
#[default] #[default]
SRGB, SRGB,
@ -27,8 +31,8 @@ impl LuminanceCalculation {
} }
} }
impl std::fmt::Display for LuminanceCalculation { impl core::fmt::Display for LuminanceCalculation {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
match self { match self {
LuminanceCalculation::SRGB => write!(f, "sRGB"), LuminanceCalculation::SRGB => write!(f, "sRGB"),
LuminanceCalculation::Perceptual => write!(f, "Perceptual"), LuminanceCalculation::Perceptual => write!(f, "Perceptual"),
@ -73,7 +77,8 @@ impl BlendMode {
} }
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Default, Clone, Copy, Eq, PartialEq, DynAny, specta::Type, Hash)] #[cfg_attr(feature = "std", derive(specta::Type))]
#[derive(Debug, Default, Clone, Copy, Eq, PartialEq, DynAny, Hash)]
pub enum BlendMode { pub enum BlendMode {
#[default] #[default]
// Basic group // Basic group
@ -116,8 +121,8 @@ pub enum BlendMode {
Luminosity, Luminosity,
} }
impl std::fmt::Display for BlendMode { impl core::fmt::Display for BlendMode {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
match self { match self {
BlendMode::Normal => write!(f, "Normal"), BlendMode::Normal => write!(f, "Normal"),
@ -368,8 +373,7 @@ fn blend_node(input: (Color, Color), blend_mode: BlendMode, opacity: f64) -> Col
BlendMode::Color => backdrop.blend_color(source_color), BlendMode::Color => backdrop.blend_color(source_color),
BlendMode::Luminosity => backdrop.blend_luminosity(source_color), BlendMode::Luminosity => backdrop.blend_luminosity(source_color),
} }
.lerp(backdrop, actual_opacity) .lerp(backdrop, actual_opacity);
.unwrap();
} }
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]

View file

@ -20,7 +20,8 @@ use bytemuck::{Pod, Zeroable};
#[repr(C)] #[repr(C)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "gpu", derive(Pod, Zeroable))] #[cfg_attr(feature = "gpu", derive(Pod, Zeroable))]
#[derive(Debug, Clone, Copy, PartialEq, Default, DynAny, specta::Type)] #[cfg_attr(feature = "std", derive(specta::Type))]
#[derive(Debug, Default, Clone, Copy, PartialEq, DynAny)]
pub struct Color { pub struct Color {
red: f32, red: f32,
green: f32, green: f32,
@ -30,7 +31,7 @@ pub struct Color {
#[allow(clippy::derived_hash_with_manual_eq)] #[allow(clippy::derived_hash_with_manual_eq)]
impl Hash for Color { impl Hash for Color {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) { fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
self.red.to_bits().hash(state); self.red.to_bits().hash(state);
self.green.to_bits().hash(state); self.green.to_bits().hash(state);
self.blue.to_bits().hash(state); self.blue.to_bits().hash(state);
@ -119,7 +120,6 @@ impl Color {
/// use graphene_core::raster::color::Color; /// use graphene_core::raster::color::Color;
/// let color = Color::from_hsla(0.5, 0.2, 0.3, 1.); /// let color = Color::from_hsla(0.5, 0.2, 0.3, 1.);
/// ``` /// ```
#[cfg(not(target_arch = "spirv"))]
pub fn from_hsla(hue: f32, saturation: f32, lightness: f32, alpha: f32) -> Color { pub fn from_hsla(hue: f32, saturation: f32, lightness: f32, alpha: f32) -> Color {
let temp1 = if lightness < 0.5 { let temp1 = if lightness < 0.5 {
lightness * (saturation + 1.) lightness * (saturation + 1.)
@ -127,12 +127,16 @@ impl Color {
lightness + saturation - lightness * saturation lightness + saturation - lightness * saturation
}; };
let temp2 = 2. * lightness - temp1; let temp2 = 2. * lightness - temp1;
#[cfg(not(target_arch = "spirv"))]
let rem = |x: f32| x.rem_euclid(1.);
#[cfg(target_arch = "spirv")]
let rem = |x: f32| x.rem_euclid(&1.);
let mut red = (hue + 1. / 3.).rem_euclid(1.); let mut red = rem(hue + 1. / 3.);
let mut green = hue.rem_euclid(1.); let mut green = rem(hue);
let mut blue = (hue - 1. / 3.).rem_euclid(1.); let mut blue = rem(hue - 1. / 3.);
for channel in [&mut red, &mut green, &mut blue] { fn map_channel(channel: &mut f32, temp2: f32, temp1: f32) {
*channel = if *channel * 6. < 1. { *channel = if *channel * 6. < 1. {
temp2 + (temp1 - temp2) * 6. * *channel temp2 + (temp1 - temp2) * 6. * *channel
} else if *channel * 2. < 1. { } else if *channel * 2. < 1. {
@ -144,6 +148,9 @@ impl Color {
} }
.clamp(0., 1.); .clamp(0., 1.);
} }
map_channel(&mut red, temp2, temp1);
map_channel(&mut green, temp2, temp1);
map_channel(&mut blue, temp2, temp1);
Color { red, green, blue, alpha } Color { red, green, blue, alpha }
} }
@ -427,6 +434,7 @@ impl Color {
/// let color = Color::from_rgba8(0x7C, 0x67, 0xFA, 0x61); /// let color = Color::from_rgba8(0x7C, 0x67, 0xFA, 0x61);
/// assert!("7C67FA61" == color.rgba_hex()) /// assert!("7C67FA61" == color.rgba_hex())
/// ``` /// ```
#[cfg(feature = "std")]
pub fn rgba_hex(&self) -> String { pub fn rgba_hex(&self) -> String {
format!( format!(
"{:02X?}{:02X?}{:02X?}{:02X?}", "{:02X?}{:02X?}{:02X?}{:02X?}",
@ -443,6 +451,7 @@ impl Color {
/// let color = Color::from_rgba8(0x7C, 0x67, 0xFA, 0x61); /// let color = Color::from_rgba8(0x7C, 0x67, 0xFA, 0x61);
/// assert!("7C67FA" == color.rgb_hex()) /// assert!("7C67FA" == color.rgb_hex())
/// ``` /// ```
#[cfg(feature = "std")]
pub fn rgb_hex(&self) -> String { pub fn rgb_hex(&self) -> String {
format!("{:02X?}{:02X?}{:02X?}", (self.r() * 255.) as u8, (self.g() * 255.) as u8, (self.b() * 255.) as u8,) format!("{:02X?}{:02X?}{:02X?}", (self.r() * 255.) as u8, (self.g() * 255.) as u8, (self.b() * 255.) as u8,)
} }
@ -536,9 +545,9 @@ impl Color {
/// Linearly interpolates between two colors based on t. /// Linearly interpolates between two colors based on t.
/// ///
/// T must be between 0 and 1. /// T must be between 0 and 1.
pub fn lerp(self, other: Color, t: f32) -> Option<Self> { pub fn lerp(self, other: Color, t: f32) -> Self {
assert!((0. ..=1.).contains(&t)); assert!((0. ..=1.).contains(&t));
Color::from_rgbaf32( Color::from_rgbaf32_unchecked(
self.red + ((other.red - self.red) * t), self.red + ((other.red - self.red) * t),
self.green + ((other.green - self.green) * t), self.green + ((other.green - self.green) * t),
self.blue + ((other.blue - self.blue) * t), self.blue + ((other.blue - self.blue) * t),
@ -600,6 +609,7 @@ impl Color {
blue: f(self.blue, other.blue), blue: f(self.blue, other.blue),
alpha: self.alpha, alpha: self.alpha,
}; };
#[cfg(feature = "log")]
if *self == Color::RED { if *self == Color::RED {
debug!("{} {} {} {}", color.red, color.green, color.blue, color.alpha); debug!("{} {} {} {}", color.red, color.green, color.blue, color.alpha);
} }

View file

@ -133,8 +133,8 @@ impl Gradient {
((time - self.positions[index].0) / self.positions.get(index + 1).map(|end| end.0 - self.positions[index].0).unwrap_or_default()) as f32, ((time - self.positions[index].0) / self.positions.get(index + 1).map(|end| end.0 - self.positions[index].0).unwrap_or_default()) as f32,
), ),
// Use the start or the end colour if applicable // Use the start or the end colour if applicable
(Some(v), _) | (_, Some(v)) => Some(v), (Some(v), _) | (_, Some(v)) => v,
_ => Some(Color::WHITE), _ => Color::WHITE,
}; };
// Compute the correct index to keep the positions in order // Compute the correct index to keep the positions in order
@ -146,7 +146,7 @@ impl Gradient {
let new_color = get_color(index - 1, new_position); let new_color = get_color(index - 1, new_position);
// Insert the new stop // Insert the new stop
self.positions.insert(index, (new_position, new_color)); self.positions.insert(index, (new_position, Some(new_color)));
Some(index) Some(index)
} }

View file

@ -2,6 +2,12 @@
# It is not intended for manual editing. # It is not intended for manual editing.
version = 3 version = 3
[[package]]
name = "Inflector"
version = "0.11.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fe438c63458706e03479442743baae6c88256498e6431708f6dfc520a26515d3"
[[package]] [[package]]
name = "ahash" name = "ahash"
version = "0.7.6" version = "0.7.6"
@ -274,6 +280,15 @@ dependencies = [
"crypto-common", "crypto-common",
] ]
[[package]]
name = "document-features"
version = "0.2.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e493c573fce17f00dcab13b6ac057994f3ce17d1af4dc39bfd482b83c6eb6157"
dependencies = [
"litrs",
]
[[package]] [[package]]
name = "dyn-any" name = "dyn-any"
version = "0.2.1" version = "0.2.1"
@ -365,7 +380,6 @@ version = "0.22.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "12f597d56c1bd55a811a1be189459e8fad2bbc272616375602443bdfb37fa774" checksum = "12f597d56c1bd55a811a1be189459e8fad2bbc272616375602443bdfb37fa774"
dependencies = [ dependencies = [
"num-traits",
"serde", "serde",
] ]
@ -430,6 +444,7 @@ dependencies = [
"num-traits", "num-traits",
"rand_chacha", "rand_chacha",
"serde", "serde",
"specta",
] ]
[[package]] [[package]]
@ -443,7 +458,9 @@ dependencies = [
"kurbo", "kurbo",
"log", "log",
"node-macro", "node-macro",
"once_cell",
"serde", "serde",
"specta",
] ]
[[package]] [[package]]
@ -597,6 +614,12 @@ dependencies = [
"cc", "cc",
] ]
[[package]]
name = "litrs"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f9275e0933cf8bb20f008924c0cb07a0692fe54d8064996520bf998de9eb79aa"
[[package]] [[package]]
name = "log" name = "log"
version = "0.4.17" version = "0.4.17"
@ -616,6 +639,7 @@ checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d"
name = "node-macro" name = "node-macro"
version = "0.0.0" version = "0.0.0"
dependencies = [ dependencies = [
"proc-macro2",
"quote", "quote",
"syn", "syn",
] ]
@ -651,9 +675,9 @@ dependencies = [
[[package]] [[package]]
name = "once_cell" name = "once_cell"
version = "1.16.0" version = "1.17.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "86f0b0d4bf799edbc74508c1e8bf170ff5f41238e5f8225603ca7caaae2b7860" checksum = "b7e5500299e16ebb147ae15a00a942af264cf3688f47923b8fc2cd5858f23ad3"
[[package]] [[package]]
name = "parse-zoneinfo" name = "parse-zoneinfo"
@ -664,6 +688,12 @@ dependencies = [
"regex", "regex",
] ]
[[package]]
name = "paste"
version = "1.0.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d01a5bd0424d00070b0098dd17ebca6f961a959dead1dbcbbbc1d1cd8d3deeba"
[[package]] [[package]]
name = "percent-encoding" name = "percent-encoding"
version = "2.2.0" version = "2.2.0"
@ -873,14 +903,15 @@ checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
[[package]] [[package]]
name = "rustc_codegen_spirv" name = "rustc_codegen_spirv"
version = "0.4.0" version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1ea7712b3de402aa159d0073fba3e025836f921847490f3663743eb8d6bf0220" checksum = "1b803b49618cdde99e1065af9415f489b993374765e30a6b80f2bea2cca65914"
dependencies = [ dependencies = [
"ar", "ar",
"either", "either",
"hashbrown 0.11.2", "hashbrown 0.11.2",
"indexmap", "indexmap",
"lazy_static",
"libc", "libc",
"num-traits", "num-traits",
"once_cell", "once_cell",
@ -899,9 +930,9 @@ dependencies = [
[[package]] [[package]]
name = "rustc_codegen_spirv-types" name = "rustc_codegen_spirv-types"
version = "0.4.0" version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f00bec144cf0240b503dce9fa0dd3bca1e94da3726ddbf2df055bc6155430c1b" checksum = "330aedc6b09b9bf3c58cc7fb942c1377310a9cff00fae9e4f6cc09a7a28f542e"
dependencies = [ dependencies = [
"rspirv", "rspirv",
"serde", "serde",
@ -1004,6 +1035,32 @@ dependencies = [
"serde", "serde",
] ]
[[package]]
name = "specta"
version = "0.0.6"
source = "git+https://github.com/oscartbeaumont/rspc?rev=9725ddbfe40183debc055b88c37910eb6f818eae#9725ddbfe40183debc055b88c37910eb6f818eae"
dependencies = [
"document-features",
"glam",
"once_cell",
"paste",
"serde",
"serde_json",
"specta-macros",
]
[[package]]
name = "specta-macros"
version = "0.0.6"
source = "git+https://github.com/oscartbeaumont/rspc?rev=9725ddbfe40183debc055b88c37910eb6f818eae#9725ddbfe40183debc055b88c37910eb6f818eae"
dependencies = [
"Inflector",
"proc-macro2",
"quote",
"syn",
"termcolor",
]
[[package]] [[package]]
name = "spirt" name = "spirt"
version = "0.1.0" version = "0.1.0"
@ -1034,9 +1091,9 @@ dependencies = [
[[package]] [[package]]
name = "spirv-builder" name = "spirv-builder"
version = "0.4.0" version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5095b11ee4033cfc146df5bc2ad5671c3c3d900ac46a3c30282bdf387d902250" checksum = "93f656f97ac742e5603843d2ea3ea644cdf5630635b3320ec87666385766e7ab"
dependencies = [ dependencies = [
"memchr", "memchr",
"raw-string", "raw-string",

View file

@ -24,7 +24,7 @@ base64 = "0.13"
bytemuck = { version = "1.8" } bytemuck = { version = "1.8" }
nvtx = { version = "1.1.1", optional = true } nvtx = { version = "1.1.1", optional = true }
tempfile = "3" tempfile = "3"
spirv-builder = { version = "0.4", default-features = false, features=["use-installed-tools"] } spirv-builder = { version = "0.5", default-features = false, features=["use-installed-tools"] }
tera = { version = "1.17.1" } tera = { version = "1.17.1" }
anyhow = "1.0.66" anyhow = "1.0.66"
serde_json = "1.0.91" serde_json = "1.0.91"

View file

@ -33,7 +33,7 @@ pub fn compile_spirv(network: &graph_craft::document::NodeNetwork, input_type: &
if !output.status.success() { if !output.status.success() {
return Err(anyhow::anyhow!("cargo failed: {}", String::from_utf8_lossy(&output.stderr))); return Err(anyhow::anyhow!("cargo failed: {}", String::from_utf8_lossy(&output.stderr)));
} }
Ok(output.stdout) Ok(std::fs::read(compile_dir.unwrap().to_owned() + "/shader.spv")?)
} }
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]

View file

@ -1,3 +1,3 @@
[toolchain] [toolchain]
channel = "nightly-2022-10-29" channel = "nightly-2022-12-18"
components = ["rust-src", "rustc-dev", "llvm-tools-preview", "clippy", "cargofmt", "rustc"] components = ["rust-src", "rustc-dev", "llvm-tools-preview", "clippy", "cargofmt", "rustc"]

View file

@ -30,8 +30,8 @@ pub fn create_files(matadata: &Metadata, network: &ProtoNetwork, compile_dir: &P
let cargo_toml = create_cargo_toml(matadata)?; let cargo_toml = create_cargo_toml(matadata)?;
std::fs::write(cargo_file, cargo_toml)?; std::fs::write(cargo_file, cargo_toml)?;
let toolchain_file = compile_dir.join("rust-toolchain"); let toolchain_file = compile_dir.join("rust-toolchain.toml");
let toolchain = include_str!("templates/rust-toolchain"); let toolchain = include_str!("templates/rust-toolchain.toml");
std::fs::write(toolchain_file, toolchain)?; std::fs::write(toolchain_file, toolchain)?;
// create src dir // create src dir
@ -69,7 +69,7 @@ pub fn serialize_gpu(network: &ProtoNetwork, input_type: &str, output_type: &str
nodes.push(Node { nodes.push(Node {
id, id,
fqn: fqn.to_string(), fqn: fqn.to_string().split("<").next().unwrap().to_owned(),
args: node.construction_args.new_function_args(), args: node.construction_args.new_function_args(),
}); });
} }

View file

@ -11,7 +11,7 @@ fn main() -> anyhow::Result<()> {
let compile_dir = std::env::args().nth(3).map(|x| std::path::PathBuf::from(&x)).unwrap_or(tempfile::tempdir()?.into_path()); let compile_dir = std::env::args().nth(3).map(|x| std::path::PathBuf::from(&x)).unwrap_or(tempfile::tempdir()?.into_path());
let network: NodeNetwork = serde_json::from_reader(&mut stdin)?; let network: NodeNetwork = serde_json::from_reader(&mut stdin)?;
let compiler = graph_craft::executor::Compiler {}; let compiler = graph_craft::executor::Compiler {};
let proto_network = compiler.compile(network, true); let proto_network = compiler.compile_single(network, true).unwrap();
dbg!(&compile_dir); dbg!(&compile_dir);
let metadata = compiler::Metadata::new("project".to_owned(), vec!["test@example.com".to_owned()]); let metadata = compiler::Metadata::new("project".to_owned(), vec!["test@example.com".to_owned()]);
@ -20,7 +20,9 @@ fn main() -> anyhow::Result<()> {
let result = compiler::compile(&compile_dir)?; let result = compiler::compile(&compile_dir)?;
let bytes = std::fs::read(result.module.unwrap_single())?; let bytes = std::fs::read(result.module.unwrap_single())?;
stdout.write_all(&bytes)?; // TODO: properly resolve this
let spirv_path = compile_dir.join("shader.spv");
std::fs::write(&spirv_path, &bytes)?;
Ok(()) Ok(())
} }

View file

@ -13,5 +13,5 @@ crate-type = ["dylib", "lib"]
libm = { git = "https://github.com/rust-lang/libm", tag = "0.2.5" } libm = { git = "https://github.com/rust-lang/libm", tag = "0.2.5" }
[dependencies] [dependencies]
spirv-std = { git = "https://github.com/EmbarkStudios/rust-gpu" , features= ["glam"]} spirv-std = { version = "0.5" , features= ["glam"]}
graphene-core = {path = "{{gcore_path}}", default-features = false, features = ["gpu"]} graphene-core = {path = "{{gcore_path}}", default-features = false, features = ["gpu"]}

View file

@ -1,3 +1,3 @@
[toolchain] [toolchain]
channel = "nightly-2022-10-29" channel = "nightly-2022-12-18"
components = ["rust-src", "rustc-dev", "llvm-tools-preview", "clippy", "cargofmt", "rustc"] components = ["rust-src", "rustc-dev", "llvm-tools-preview", "clippy", "cargofmt", "rustc"]

View file

@ -17,13 +17,13 @@ pub mod gpu {
#[spirv(global_invocation_id)] global_id: UVec3, #[spirv(global_invocation_id)] global_id: UVec3,
#[spirv(storage_buffer, descriptor_set = 0, binding = 0)] a: &[{{input_type}}], #[spirv(storage_buffer, descriptor_set = 0, binding = 0)] a: &[{{input_type}}],
#[spirv(storage_buffer, descriptor_set = 0, binding = 1)] y: &mut [{{output_type}}], #[spirv(storage_buffer, descriptor_set = 0, binding = 1)] y: &mut [{{output_type}}],
#[spirv(push_constant)] push_consts: &graphene_core::gpu::PushConstants, //#[spirv(push_constant)] push_consts: &graphene_core::gpu::PushConstants,
) { ) {
let gid = global_id.x as usize; let gid = global_id.x as usize;
// Only process up to n, which is the length of the buffers. // Only process up to n, which is the length of the buffers.
if global_id.x < push_consts.n { //if global_id.x < push_consts.n {
y[gid] = node_graph(a[gid]); y[gid] = node_graph(a[gid]);
} //}
} }
fn node_graph(input: {{input_type}}) -> {{output_type}} { fn node_graph(input: {{input_type}}) -> {{output_type}} {

View file

@ -44,6 +44,7 @@ pub enum TaggedValue {
FillType(graphene_core::vector::style::FillType), FillType(graphene_core::vector::style::FillType),
GradientType(graphene_core::vector::style::GradientType), GradientType(graphene_core::vector::style::GradientType),
GradientPositions(Vec<(f64, Option<graphene_core::Color>)>), GradientPositions(Vec<(f64, Option<graphene_core::Color>)>),
Quantization(graphene_core::quantization::QuantizationChannels),
} }
#[allow(clippy::derived_hash_with_manual_eq)] #[allow(clippy::derived_hash_with_manual_eq)]
@ -175,6 +176,10 @@ impl Hash for TaggedValue {
color.hash(state); color.hash(state);
} }
} }
Self::Quantization(quantized_image) => {
31.hash(state);
quantized_image.hash(state);
}
} }
} }
} }
@ -213,6 +218,7 @@ impl<'a> TaggedValue {
TaggedValue::FillType(x) => Box::new(x), TaggedValue::FillType(x) => Box::new(x),
TaggedValue::GradientType(x) => Box::new(x), TaggedValue::GradientType(x) => Box::new(x),
TaggedValue::GradientPositions(x) => Box::new(x), TaggedValue::GradientPositions(x) => Box::new(x),
TaggedValue::Quantization(x) => Box::new(x),
} }
} }
@ -250,6 +256,7 @@ impl<'a> TaggedValue {
TaggedValue::FillType(_) => concrete!(graphene_core::vector::style::FillType), TaggedValue::FillType(_) => concrete!(graphene_core::vector::style::FillType),
TaggedValue::GradientType(_) => concrete!(graphene_core::vector::style::GradientType), TaggedValue::GradientType(_) => concrete!(graphene_core::vector::style::GradientType),
TaggedValue::GradientPositions(_) => concrete!(Vec<(f64, Option<graphene_core::Color>)>), TaggedValue::GradientPositions(_) => concrete!(Vec<(f64, Option<graphene_core::Color>)>),
TaggedValue::Quantization(_) => concrete!(graphene_core::quantization::QuantizationChannels),
} }
} }
} }

View file

@ -1,121 +1,61 @@
use graph_craft::document::*; use graph_craft::document::*;
use graph_craft::proto::*; use graphene_core::raster::*;
use graphene_core::raster::Image;
use graphene_core::value::ValueNode; use graphene_core::value::ValueNode;
use graphene_core::Node; use graphene_core::*;
use bytemuck::Pod; use bytemuck::Pod;
use core::marker::PhantomData; use core::marker::PhantomData;
use dyn_any::StaticTypeSized; use dyn_any::StaticTypeSized;
pub struct MapGpuNode<NN: Node<()>, I: IntoIterator<Item = S>, S: StaticTypeSized + Sync + Send + Pod, O: StaticTypeSized + Sync + Send + Pod>(pub NN, PhantomData<(S, I, O)>); pub struct MapGpuNode<O, Network> {
network: Network,
impl<'n, I: IntoIterator<Item = S>, NN: Node<(), Output = &'n NodeNetwork> + Copy, S: StaticTypeSized + Sync + Send + Pod, O: StaticTypeSized + Sync + Send + Pod> Node<I> _o: PhantomData<O>,
for &MapGpuNode<NN, I, S, O>
{
type Output = Vec<O>;
fn eval(self, input: I) -> Self::Output {
let network = self.0.eval(());
map_gpu_impl(network, input)
}
} }
fn map_gpu_impl<I: IntoIterator<Item = S>, S: StaticTypeSized + Sync + Send + Pod, O: StaticTypeSized + Sync + Send + Pod>(network: &NodeNetwork, input: I) -> Vec<O> { #[node_macro::node_fn(MapGpuNode<_O>)]
fn map_gpu<I: IntoIterator<Item = S>, S: StaticTypeSized + Sync + Send + Pod, _O: StaticTypeSized + Sync + Send + Pod>(input: I, network: &'any_input NodeNetwork) -> Vec<_O> {
use graph_craft::executor::Executor; use graph_craft::executor::Executor;
let bytes = compilation_client::compile_sync::<S, O>(network.clone()).unwrap(); let bytes = compilation_client::compile_sync::<S, _O>(network.clone()).unwrap();
let words = unsafe { std::slice::from_raw_parts(bytes.as_ptr() as *const u32, bytes.len() / 4) }; let words = unsafe { std::slice::from_raw_parts(bytes.as_ptr() as *const u32, bytes.len() / 4) };
use wgpu_executor::{Context, GpuExecutor}; use wgpu_executor::{Context, GpuExecutor};
let executor: GpuExecutor<S, O> = GpuExecutor::new(Context::new_sync().unwrap(), words.into(), "gpu::eval".into()).unwrap(); let executor: GpuExecutor<S, _O> = GpuExecutor::new(Context::new_sync().unwrap(), words.into(), "gpu::eval".into()).unwrap();
let data: Vec<_> = input.into_iter().collect(); let data: Vec<_> = input.into_iter().collect();
let result = executor.execute(Box::new(data)).unwrap(); let result = executor.execute(Box::new(data)).unwrap();
let result = dyn_any::downcast::<Vec<O>>(result).unwrap(); let result = dyn_any::downcast::<Vec<_O>>(result).unwrap();
*result *result
} }
impl<'n, I: IntoIterator<Item = S>, NN: Node<(), Output = &'n NodeNetwork> + Copy, S: StaticTypeSized + Sync + Send + Pod, O: StaticTypeSized + Sync + Send + Pod> Node<I> for MapGpuNode<NN, I, S, O> {
type Output = Vec<O>; pub struct MapGpuSingleImageNode<N> {
fn eval(self, input: I) -> Self::Output { node: N,
let network = self.0.eval(());
map_gpu_impl(network, input)
}
} }
impl<I: IntoIterator<Item = S>, NN: Node<()>, S: StaticTypeSized + Sync + Pod + Send, O: StaticTypeSized + Sync + Send + Pod> MapGpuNode<NN, I, S, O> { #[node_macro::node_fn(MapGpuSingleImageNode)]
pub const fn new(network: NN) -> Self { fn map_gpu_single_image(input: Image, node: String) -> Image {
MapGpuNode(network, PhantomData) use graph_craft::document::*;
} use graph_craft::NodeIdentifier;
}
let identifier = NodeIdentifier { name: std::borrow::Cow::Owned(node) };
pub struct MapGpuSingleImageNode<NN: Node<(), Output = String>>(pub NN);
let network = NodeNetwork {
impl<NN: Node<(), Output = String> + Copy> Node<Image> for MapGpuSingleImageNode<NN> { inputs: vec![0],
type Output = Image; disabled: vec![],
fn eval(self, input: Image) -> Self::Output { previous_outputs: None,
let node = self.0.eval(()); outputs: vec![NodeOutput::new(0, 0)],
use graph_craft::document::*; nodes: [(
0,
let identifier = NodeIdentifier { DocumentNode {
name: std::borrow::Cow::Owned(node), name: "Image filter Node".into(),
types: std::borrow::Cow::Borrowed(&[]), inputs: vec![NodeInput::Network(concrete!(Color))],
}; implementation: DocumentNodeImplementation::Unresolved(identifier),
metadata: DocumentNodeMetadata::default(),
let network = NodeNetwork { },
inputs: vec![0], )]
disabled: vec![], .into_iter()
previous_output: None, .collect(),
output: 0, };
nodes: [(
0, let value_network = ValueNode::new(network);
DocumentNode { let map_node = MapGpuNode::new(value_network);
name: "Image filter Node".into(), let data = map_node.eval(input.data.clone());
inputs: vec![NodeInput::Network], Image { data, ..input }
implementation: DocumentNodeImplementation::Unresolved(identifier),
metadata: DocumentNodeMetadata::default(),
},
)]
.into_iter()
.collect(),
};
let value_network = ValueNode::new(network);
let map_node = MapGpuNode::new(&value_network);
let data = map_node.eval(input.data.clone());
Image { data, ..input }
}
}
impl<NN: Node<(), Output = String> + Copy> Node<Image> for &MapGpuSingleImageNode<NN> {
type Output = Image;
fn eval(self, input: Image) -> Self::Output {
let node = self.0.eval(());
use graph_craft::document::*;
let identifier = NodeIdentifier {
name: std::borrow::Cow::Owned(node),
types: std::borrow::Cow::Borrowed(&[]),
};
let network = NodeNetwork {
inputs: vec![0],
output: 0,
disabled: vec![],
previous_output: None,
nodes: [(
0,
DocumentNode {
name: "Image filter Node".into(),
inputs: vec![NodeInput::Network],
implementation: DocumentNodeImplementation::Unresolved(identifier),
metadata: DocumentNodeMetadata::default(),
},
)]
.into_iter()
.collect(),
};
let value_network = ValueNode::new(network);
let map_node = MapGpuNode::new(&value_network);
let data = map_node.eval(input.data.clone());
Image { data, ..input }
}
} }

View file

@ -1,54 +1,61 @@
use graphene_core::raster::{Color, Image}; use dyn_any::{DynAny, StaticType};
use graphene_core::quantization::*;
use graphene_core::raster::{Color, ImageFrame};
use graphene_core::Node; use graphene_core::Node;
/// The `GenerateQuantizationNode` encodes the brightness of each channel of the image as an integer number /// The `GenerateQuantizationNode` encodes the brightness of each channel of the image as an integer number
/// sepified by the samples parameter. This node is used to asses the loss of visual information when /// sepified by the samples parameter. This node is used to asses the loss of visual information when
/// quantizing the image using different fit functions. /// quantizing the image using different fit functions.
pub struct GenerateQuantizationNode<N: Node<(), Output = u32>, M: Node<(), Output = u32>> { pub struct GenerateQuantizationNode<N, M> {
samples: N, samples: N,
function: M, function: M,
} }
#[node_macro::node_fn(GenerateQuantizationNode)] #[node_macro::node_fn(GenerateQuantizationNode)]
fn generate_quantization_fn(image: Image, samples: u32, function: u32) -> Image { fn generate_quantization_fn(image_frame: ImageFrame, samples: u32, function: u32) -> [Quantization; 4] {
// Scale the input image, this can be removed by adding an extra parameter to the fit function. let image = image_frame.image;
let max_energy = 16380.;
let data: Vec<f64> = image.data.iter().flat_map(|x| vec![x.r() as f64, x.g() as f64, x.b() as f64]).collect(); let len = image.data.len().min(10000);
let data: Vec<f64> = data.iter().map(|x| x * max_energy).collect(); let mut channels: Vec<_> = (0..4).map(|_| Vec::with_capacity(image.data.len())).collect();
image
.data
.iter()
.enumerate()
.filter(|(i, _)| i % (image.data.len() / len) == 0)
.map(|(_, x)| vec![x.r() as f64, x.g() as f64, x.b() as f64, x.a() as f64])
.for_each(|x| x.into_iter().enumerate().for_each(|(i, value)| channels[i].push(value)));
log::info!("Quantizing {} samples", channels[0].len());
log::info!("In {} channels", channels.len());
let quantization: Vec<Quantization> = channels.into_iter().map(|x| generate_quantization_per_channel(x, samples)).collect();
core::array::from_fn(|i| quantization[i].clone())
}
fn generate_quantization_per_channel(data: Vec<f64>, samples: u32) -> Quantization {
let mut dist = autoquant::integrate_distribution(data); let mut dist = autoquant::integrate_distribution(data);
autoquant::drop_duplicates(&mut dist); autoquant::drop_duplicates(&mut dist);
let dist = autoquant::normalize_distribution(dist.as_slice()); let dist = autoquant::normalize_distribution(dist.as_slice());
let max = dist.last().unwrap().0; let max = dist.last().unwrap().0;
let linear = Box::new(autoquant::SimpleFitFn { /*let linear = Box::new(autoquant::SimpleFitFn {
function: move |x| x / max, function: move |x| x / max,
inverse: move |x| x * max, inverse: move |x| x * max,
name: "identity", name: "identity",
}); });*/
let best = match function {
0 => linear as Box<dyn autoquant::FitFn>, let linear = Quantization {
1 => linear as Box<dyn autoquant::FitFn>, fn_index: 0,
2 => Box::new(autoquant::models::OptimizedLog::new(dist, 20)) as Box<dyn autoquant::FitFn>, a: max as f32,
_ => linear as Box<dyn autoquant::FitFn>, b: 0.,
c: 0.,
d: 0.,
}; };
let log_fit = autoquant::models::OptimizedLog::new(dist, samples as u64);
let roundtrip = |sample: f32| -> f32 { let parameters = log_fit.parameters();
let encoded = autoquant::encode(sample as f64 * max_energy, best.as_ref(), samples); let log_fit = Quantization {
let decoded = autoquant::decode(encoded, best.as_ref(), samples) / max_energy; fn_index: 1,
log::trace!("{} enc: {} dec: {}", sample, encoded, decoded); a: parameters[0] as f32,
decoded as f32 b: parameters[1] as f32,
c: parameters[2] as f32,
d: parameters[3] as f32,
}; };
log_fit
let new_data = image
.data
.iter()
.map(|c| {
let r = roundtrip(c.r());
let g = roundtrip(c.g());
let b = roundtrip(c.b());
let a = c.a();
Color::from_rgbaf32_unchecked(r, g, b, a)
})
.collect();
Image { data: new_data, ..image }
} }

View file

@ -158,6 +158,6 @@ mod tests {
let compiler = Compiler {}; let compiler = Compiler {};
let protograph = compiler.compile_single(network, true).expect("Graph should be generated"); let protograph = compiler.compile_single(network, true).expect("Graph should be generated");
let exec = DynamicExecutor::new(protograph).map(|e| panic!("The network should not type check: {:#?}", e)).unwrap_err(); let _exec = DynamicExecutor::new(protograph).map(|e| panic!("The network should not type check: {:#?}", e)).unwrap_err();
} }
} }

View file

@ -24,6 +24,8 @@ use crate::executor::NodeContainer;
use dyn_any::StaticType; use dyn_any::StaticType;
use graphene_core::quantization::QuantizationChannels;
macro_rules! construct_node { macro_rules! construct_node {
($args: ident, $path:ty, [$($type:tt),*]) => {{ ($args: ident, $path:ty, [$($type:tt),*]) => {{
let mut args: Vec<TypeErasedPinnedRef<'static>> = $args.clone(); let mut args: Vec<TypeErasedPinnedRef<'static>> = $args.clone();
@ -133,6 +135,8 @@ fn node_registry() -> HashMap<NodeIdentifier, HashMap<NodeIOTypes, NodeConstruct
register_node!(graphene_core::ops::AddParameterNode<_>, input: f64, params: [&f64]), register_node!(graphene_core::ops::AddParameterNode<_>, input: f64, params: [&f64]),
register_node!(graphene_core::ops::AddParameterNode<_>, input: &f64, params: [&f64]), register_node!(graphene_core::ops::AddParameterNode<_>, input: &f64, params: [&f64]),
register_node!(graphene_core::ops::SomeNode, input: ImageFrame, params: []), register_node!(graphene_core::ops::SomeNode, input: ImageFrame, params: []),
#[cfg(feature = "gpu")]
register_node!(graphene_std::executor::MapGpuSingleImageNode<_>, input: Image, params: [String]),
vec![( vec![(
NodeIdentifier::new("graphene_core::structural::ComposeNode<_, _, _>"), NodeIdentifier::new("graphene_core::structural::ComposeNode<_, _, _>"),
|args| { |args| {
@ -292,9 +296,23 @@ fn node_registry() -> HashMap<NodeIdentifier, HashMap<NodeIOTypes, NodeConstruct
}, },
NodeIOTypes::new(concrete!(Image), concrete!(&Image), vec![]), NodeIOTypes::new(concrete!(Image), concrete!(&Image), vec![]),
), ),
(
NodeIdentifier::new("graphene_std::memo::CacheNode"),
|_| {
let node: CacheNode<QuantizationChannels> = graphene_std::memo::CacheNode::new();
let any = DynAnyRefNode::new(node);
any.into_type_erased()
},
NodeIOTypes::new(concrete!(QuantizationChannels), concrete!(&QuantizationChannels), vec![]),
),
], ],
register_node!(graphene_core::structural::ConsNode<_, _>, input: Image, params: [&str]), register_node!(graphene_core::structural::ConsNode<_, _>, input: Image, params: [&str]),
register_node!(graphene_std::raster::ImageFrameNode<_>, input: Image, params: [DAffine2]), register_node!(graphene_std::raster::ImageFrameNode<_>, input: Image, params: [DAffine2]),
#[cfg(feature = "quantization")]
register_node!(graphene_std::quantization::GenerateQuantizationNode<_, _>, input: ImageFrame, params: [u32, u32]),
raster_node!(graphene_core::quantization::QuantizeNode<_>, params: [QuantizationChannels]),
raster_node!(graphene_core::quantization::DeQuantizeNode<_>, params: [QuantizationChannels]),
register_node!(graphene_core::ops::CloneNode<_>, input: &QuantizationChannels, params: []),
register_node!(graphene_core::vector::TransformNode<_, _, _, _>, input: VectorData, params: [DVec2, f64, DVec2, DVec2]), register_node!(graphene_core::vector::TransformNode<_, _, _, _>, input: VectorData, params: [DVec2, f64, DVec2, DVec2]),
register_node!(graphene_core::vector::SetFillNode<_, _, _, _, _, _, _>, input: VectorData, params: [ graphene_core::vector::style::FillType, graphene_core::Color, graphene_core::vector::style::GradientType, DVec2, DVec2, DAffine2, Vec<(f64, Option<graphene_core::Color>)>]), register_node!(graphene_core::vector::SetFillNode<_, _, _, _, _, _, _>, input: VectorData, params: [ graphene_core::vector::style::FillType, graphene_core::Color, graphene_core::vector::style::GradientType, DVec2, DVec2, DAffine2, Vec<(f64, Option<graphene_core::Color>)>]),
register_node!(graphene_core::vector::SetStrokeNode<_, _, _, _, _, _, _>, input: VectorData, params: [graphene_core::Color, f64, Vec<f32>, f64, graphene_core::vector::style::LineCap, graphene_core::vector::style::LineJoin, f64]), register_node!(graphene_core::vector::SetStrokeNode<_, _, _, _, _, _, _>, input: VectorData, params: [graphene_core::Color, f64, Vec<f32>, f64, graphene_core::vector::style::LineCap, graphene_core::vector::style::LineJoin, f64]),
@ -304,158 +322,6 @@ fn node_registry() -> HashMap<NodeIdentifier, HashMap<NodeIOTypes, NodeConstruct
input: graphene_core::vector::bezier_rs::Subpath<graphene_core::uuid::ManipulatorGroupId>, input: graphene_core::vector::bezier_rs::Subpath<graphene_core::uuid::ManipulatorGroupId>,
params: [] params: []
), ),
/*
(NodeIdentifier::new("graphene_std::raster::ImageNode", &[concrete!("&str")]), |_proto_node, stack| {
stack.push_fn(|_nodes| {
let image = FnNode::new(|s: &str| graphene_std::raster::image_node::<&str>().eval(s).unwrap());
let node: DynAnyNode<_, &str, _, _> = DynAnyNode::new(image);
node.into_type_erased()
})
}),
(NodeIdentifier::new("graphene_std::raster::ExportImageNode", &[concrete!("&str")]), |proto_node, stack| {
stack.push_fn(|nodes| {
let pre_node = nodes.get(proto_node.input.unwrap_node() as usize).unwrap();
let image = FnNode::new(|input: (Image, &str)| graphene_std::raster::export_image_node().eval(input).unwrap());
let node: DynAnyNode<_, (Image, &str), _, _> = DynAnyNode::new(image);
let node = (pre_node).then(node);
node.into_type_erased()
})
}),
(
NodeIdentifier::new("graphene_core::structural::ConsNode", &[concrete!("Image"), concrete!("&str")]),
|proto_node, stack| {
let node_id = proto_node.input.unwrap_node() as usize;
if let ConstructionArgs::Nodes(cons_node_arg) = proto_node.construction_args {
stack.push_fn(move |nodes| {
let pre_node = nodes.get(node_id).unwrap();
let cons_node_arg = nodes.get(cons_node_arg[0] as usize).unwrap();
let cons_node = ConsNode::new(DowncastNode::<_, &str>::new(cons_node_arg));
let node: DynAnyNode<_, Image, _, _> = DynAnyNode::new(cons_node);
let node = (pre_node).then(node);
node.into_type_erased()
})
} else {
unimplemented!()
}
},
),
(NodeIdentifier::new("graphene_std::vector::generator_nodes::UnitCircleGenerator", &[]), |_proto_node, stack| {
stack.push_fn(|_nodes| DynAnyNode::new(graphene_std::vector::generator_nodes::UnitCircleGenerator).into_type_erased())
}),
(NodeIdentifier::new("graphene_std::vector::generator_nodes::UnitSquareGenerator", &[]), |_proto_node, stack| {
stack.push_fn(|_nodes| DynAnyNode::new(graphene_std::vector::generator_nodes::UnitSquareGenerator).into_type_erased())
}),
(NodeIdentifier::new("graphene_std::vector::generator_nodes::BlitSubpath", &[]), |proto_node, stack| {
stack.push_fn(move |nodes| {
let ConstructionArgs::Nodes(construction_nodes) = proto_node.construction_args else { unreachable!("BlitSubpath without subpath input") };
let value_node = nodes.get(construction_nodes[0] as usize).unwrap();
let input_node: DowncastBothNode<_, (), Subpath> = DowncastBothNode::new(value_node);
let node = DynAnyNode::new(graphene_std::vector::generator_nodes::BlitSubpath::new(input_node));
if let ProtoNodeInput::Node(node_id) = proto_node.input {
let pre_node = nodes.get(node_id as usize).unwrap();
(pre_node).then(node).into_type_erased()
} else {
node.into_type_erased()
}
})
}),
(NodeIdentifier::new("graphene_std::vector::generator_nodes::TransformSubpathNode", &[]), |proto_node, stack| {
stack.push_fn(move |nodes| {
let ConstructionArgs::Nodes(construction_nodes) = proto_node.construction_args else { unreachable!("TransformSubpathNode without subpath input") };
let translate_node: DowncastBothNode<_, (), DVec2> = DowncastBothNode::new(nodes.get(construction_nodes[0] as usize).unwrap());
let rotate_node: DowncastBothNode<_, (), f64> = DowncastBothNode::new(nodes.get(construction_nodes[1] as usize).unwrap());
let scale_node: DowncastBothNode<_, (), DVec2> = DowncastBothNode::new(nodes.get(construction_nodes[2] as usize).unwrap());
let shear_node: DowncastBothNode<_, (), DVec2> = DowncastBothNode::new(nodes.get(construction_nodes[3] as usize).unwrap());
let node = DynAnyNode::new(graphene_std::vector::generator_nodes::TransformSubpathNode::new(translate_node, rotate_node, scale_node, shear_node));
if let ProtoNodeInput::Node(node_id) = proto_node.input {
let pre_node = nodes.get(node_id as usize).unwrap();
(pre_node).then(node).into_type_erased()
} else {
node.into_type_erased()
}
})
}),
#[cfg(feature = "gpu")]
(
NodeIdentifier::new("graphene_std::executor::MapGpuNode", &[concrete!("&TypeErasedNode"), concrete!("Color"), concrete!("Color")]),
|proto_node, stack| {
if let ConstructionArgs::Nodes(operation_node_id) = proto_node.construction_args {
stack.push_fn(move |nodes| {
info!("Map image Depending upon id {:?}", operation_node_id);
let operation_node = nodes.get(operation_node_id[0] as usize).unwrap();
let input_node: DowncastBothNode<_, (), &graph_craft::document::NodeNetwork> = DowncastBothNode::new(operation_node);
let map_node: graphene_std::executor::MapGpuNode<_, Vec<u32>, u32, u32> = graphene_std::executor::MapGpuNode::new(input_node);
let map_node = DynAnyNode::new(map_node);
if let ProtoNodeInput::Node(node_id) = proto_node.input {
let pre_node = nodes.get(node_id as usize).unwrap();
(pre_node).then(map_node).into_type_erased()
} else {
map_node.into_type_erased()
}
})
} else {
unimplemented!()
}
},
),
#[cfg(feature = "gpu")]
(
NodeIdentifier::new("graphene_std::executor::MapGpuSingleImageNode", &[concrete!("&TypeErasedNode")]),
|proto_node, stack| {
if let ConstructionArgs::Nodes(operation_node_id) = proto_node.construction_args {
stack.push_fn(move |nodes| {
info!("Map image Depending upon id {:?}", operation_node_id);
let operation_node = nodes.get(operation_node_id[0] as usize).unwrap();
let input_node: DowncastBothNode<_, (), String> = DowncastBothNode::new(operation_node);
let map_node = graphene_std::executor::MapGpuSingleImageNode(input_node);
let map_node = DynAnyNode::new(map_node);
if let ProtoNodeInput::Node(node_id) = proto_node.input {
let pre_node = nodes.get(node_id as usize).unwrap();
(pre_node).then(map_node).into_type_erased()
} else {
map_node.into_type_erased()
}
})
} else {
unimplemented!()
}
},
),
#[cfg(feature = "quantization")]
(
NodeIdentifier::new("graphene_std::quantization::GenerateQuantizationNode", &[concrete!("&TypeErasedNode")]),
|proto_node, stack| {
if let ConstructionArgs::Nodes(operation_node_id) = proto_node.construction_args {
stack.push_fn(move |nodes| {
info!("Quantization Depending upon id {:?}", operation_node_id);
let samples_node = nodes.get(operation_node_id[0] as usize).unwrap();
let index_node = nodes.get(operation_node_id[1] as usize).unwrap();
let samples_node: DowncastBothNode<_, (), u32> = DowncastBothNode::new(samples_node);
let index_node: DowncastBothNode<_, (), u32> = DowncastBothNode::new(index_node);
let map_node = graphene_std::quantization::GenerateQuantizationNode::new(samples_node, index_node);
let map_node = DynAnyNode::new(map_node);
if let ProtoNodeInput::Node(node_id) = proto_node.input {
let pre_node = nodes.get(node_id as usize).unwrap();
(pre_node).then(map_node).into_type_erased()
} else {
map_node.into_type_erased()
}
})
} else {
unimplemented!()
}
},
),
<<<<<<< HEAD
*/
]; ];
let mut map: HashMap<NodeIdentifier, HashMap<NodeIOTypes, NodeConstructor>> = HashMap::new(); let mut map: HashMap<NodeIdentifier, HashMap<NodeIOTypes, NodeConstructor>> = HashMap::new();
for (id, c, types) in node_types.into_iter().flatten() { for (id, c, types) in node_types.into_iter().flatten() {
@ -466,101 +332,7 @@ fn node_registry() -> HashMap<NodeIdentifier, HashMap<NodeIOTypes, NodeConstruct
pub static NODE_REGISTRY: Lazy<HashMap<NodeIdentifier, HashMap<NodeIOTypes, NodeConstructor>>> = Lazy::new(|| node_registry()); pub static NODE_REGISTRY: Lazy<HashMap<NodeIdentifier, HashMap<NodeIOTypes, NodeConstructor>>> = Lazy::new(|| node_registry());
/*
#[cfg(test)] #[cfg(test)]
mod protograph_testing { mod protograph_testing {
use borrow_stack::BorrowStack; // TODO: adde tests testing the node registry
use super::*;
#[test]
fn add_values() {
let stack = FixedSizeStack::new(256);
let val_1_protonode = ProtoNode::value(ConstructionArgs::Value(Box::new(2u32)));
constrcut_node(val_1_protonode, &stack);
let val_2_protonode = ProtoNode::value(ConstructionArgs::Value(Box::new(40u32)));
constrcut_node(val_2_protonode, &stack);
let cons_protonode = ProtoNode {
construction_args: ConstructionArgs::Nodes(vec![1]),
input: ProtoNodeInput::Node(0),
identifier: NodeIdentifier::new("graphene_core::structural::ConsNode", &[concrete!("u32"), concrete!("u32")]),
};
constrcut_node(cons_protonode, &stack);
let add_protonode = ProtoNode {
construction_args: ConstructionArgs::Nodes(vec![]),
input: ProtoNodeInput::Node(2),
identifier: NodeIdentifier::new("graphene_core::ops::AddNode", &[concrete!("u32"), concrete!("u32")]),
};
constrcut_node(add_protonode, &stack);
let result = unsafe { stack.get()[3].eval(Box::new(())) };
let val = *dyn_any::downcast::<u32>(result).unwrap();
assert_eq!(val, 42);
}
#[test]
fn grayscale_color() {
let stack = FixedSizeStack::new(256);
let val_protonode = ProtoNode::value(ConstructionArgs::Value(Box::new(Color::from_rgb8(10, 20, 30))));
constrcut_node(val_protonode, &stack);
let grayscale_protonode = ProtoNode {
construction_args: ConstructionArgs::Nodes(vec![]),
input: ProtoNodeInput::Node(0),
identifier: NodeIdentifier::new("graphene_core::raster::GrayscaleColorNode", &[]),
};
constrcut_node(grayscale_protonode, &stack);
let result = unsafe { stack.get()[1].eval(Box::new(())) };
let val = *dyn_any::downcast::<Color>(result).unwrap();
assert_eq!(val, Color::from_rgb8(20, 20, 20));
}
#[test]
fn load_image() {
let stack = FixedSizeStack::new(256);
let image_protonode = ProtoNode {
construction_args: ConstructionArgs::Nodes(vec![]),
input: ProtoNodeInput::None,
identifier: NodeIdentifier::new("graphene_std::raster::ImageNode", &[concrete!("&str")]),
};
constrcut_node(image_protonode, &stack);
let result = unsafe { stack.get()[0].eval(Box::new("../gstd/test-image-1.png")) };
let image = *dyn_any::downcast::<Image>(result).unwrap();
assert_eq!(image.height, 240);
}
#[test]
fn grayscale_map_image() {
let stack = FixedSizeStack::new(256);
let image_protonode = ProtoNode {
construction_args: ConstructionArgs::Nodes(vec![]),
input: ProtoNodeInput::None,
identifier: NodeIdentifier::new("graphene_std::raster::ImageNode", &[concrete!("&str")]),
};
constrcut_node(image_protonode, &stack);
let grayscale_protonode = ProtoNode {
construction_args: ConstructionArgs::Nodes(vec![]),
input: ProtoNodeInput::None,
identifier: NodeIdentifier::new("graphene_core::raster::GrayscaleColorNode", &[]),
};
constrcut_node(grayscale_protonode, &stack);
let image_map_protonode = ProtoNode {
construction_args: ConstructionArgs::Nodes(vec![1]),
input: ProtoNodeInput::Node(0),
identifier: NodeIdentifier::new("graphene_std::raster::MapImageNode", &[]),
};
constrcut_node(image_map_protonode, &stack);
let result = unsafe { stack.get()[2].eval(Box::new("../gstd/test-image-1.png")) };
let image = *dyn_any::downcast::<Image>(result).unwrap();
assert!(!image.data.iter().any(|c| c.r() != c.b() || c.b() != c.g()));
}
} }
*/

View file

@ -17,6 +17,7 @@ pub struct GpuExecutor<'a, I: StaticTypeSized, O> {
impl<'a, I: StaticTypeSized, O> GpuExecutor<'a, I, O> { impl<'a, I: StaticTypeSized, O> GpuExecutor<'a, I, O> {
pub fn new(context: Context, shader: Cow<'a, [u32]>, entry_point: String) -> anyhow::Result<Self> { pub fn new(context: Context, shader: Cow<'a, [u32]>, entry_point: String) -> anyhow::Result<Self> {
log::info!("Creating executor");
Ok(Self { Ok(Self {
context, context,
entry_point, entry_point,
@ -28,6 +29,7 @@ impl<'a, I: StaticTypeSized, O> GpuExecutor<'a, I, O> {
impl<'a, I: StaticTypeSized + Sync + Pod + Send, O: StaticTypeSized + Send + Sync + Pod> Executor for GpuExecutor<'a, I, O> { impl<'a, I: StaticTypeSized + Sync + Pod + Send, O: StaticTypeSized + Send + Sync + Pod> Executor for GpuExecutor<'a, I, O> {
fn execute<'i, 's: 'i>(&'s self, input: Any<'i>) -> Result<Any<'i>, Box<dyn std::error::Error>> { fn execute<'i, 's: 'i>(&'s self, input: Any<'i>) -> Result<Any<'i>, Box<dyn std::error::Error>> {
log::info!("Executing shader");
let input = dyn_any::downcast::<Vec<I>>(input).expect("Wrong input type"); let input = dyn_any::downcast::<Vec<I>>(input).expect("Wrong input type");
let context = &self.context; let context = &self.context;
let future = execute_shader(context.device.clone(), context.queue.clone(), self.shader.to_vec(), *input, self.entry_point.clone()); let future = execute_shader(context.device.clone(), context.queue.clone(), self.shader.to_vec(), *input, self.entry_point.clone());
@ -40,6 +42,11 @@ impl<'a, I: StaticTypeSized + Sync + Pod + Send, O: StaticTypeSized + Send + Syn
async fn execute_shader<I: Pod + Send + Sync, O: Pod + Send + Sync>(device: Arc<wgpu::Device>, queue: Arc<wgpu::Queue>, shader: Vec<u32>, data: Vec<I>, entry_point: String) -> Option<Vec<O>> { async fn execute_shader<I: Pod + Send + Sync, O: Pod + Send + Sync>(device: Arc<wgpu::Device>, queue: Arc<wgpu::Queue>, shader: Vec<u32>, data: Vec<I>, entry_point: String) -> Option<Vec<O>> {
// Loads the shader from WGSL // Loads the shader from WGSL
dbg!(&shader);
//write shader to file
use std::io::Write;
let mut file = std::fs::File::create("/tmp/shader.spv").unwrap();
file.write_all(bytemuck::cast_slice(&shader)).unwrap();
let cs_module = device.create_shader_module(wgpu::ShaderModuleDescriptor { let cs_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
label: None, label: None,
source: wgpu::ShaderSource::SpirV(shader.into()), source: wgpu::ShaderSource::SpirV(shader.into()),
@ -122,7 +129,7 @@ async fn execute_shader<I: Pod + Send + Sync, O: Pod + Send + Sync>(device: Arc<
cpass.set_pipeline(&compute_pipeline); cpass.set_pipeline(&compute_pipeline);
cpass.set_bind_group(0, &bind_group, &[]); cpass.set_bind_group(0, &bind_group, &[]);
cpass.insert_debug_marker("compute node network evaluation"); cpass.insert_debug_marker("compute node network evaluation");
cpass.dispatch_workgroups(data.len() as u32, 1, 1); // Number of cells to run, the (x,y,z) size of item being processed cpass.dispatch_workgroups(data.len().min(65535) as u32, 1, 1); // Number of cells to run, the (x,y,z) size of item being processed
} }
// Sets adds copy operation to command encoder. // Sets adds copy operation to command encoder.
// Will copy data from storage buffer on GPU to staging buffer on CPU. // Will copy data from storage buffer on GPU to staging buffer on CPU.