Implement experimental WebGPU support (#1238)

* Web gpu execution MVP

Ready infrastructure for wgpu experimentation

Start implementing simple gpu test case

Fix Extract Node not working with nested networks

Convert inputs for extracted node to network inputs

Fix missing cors headers

Feature gate gcore to make it once again no-std compatible

Add skeleton structure gpu shader

Work on gpu node graph output saving

Fix Get and Set nodes

Fix storage nodes

Fix shader construction errors -> spirv errors

Add unsafe version

Add once cell node

Web gpu execution MVP
This commit is contained in:
Dennis Kobert 2023-05-27 19:27:46 +02:00 committed by Keavon Chambers
parent 4bd9fbd073
commit 0586d52f3a
33 changed files with 1080 additions and 239 deletions

292
Cargo.lock generated
View file

@ -8,6 +8,15 @@ version = "0.11.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fe438c63458706e03479442743baae6c88256498e6431708f6dfc520a26515d3" checksum = "fe438c63458706e03479442743baae6c88256498e6431708f6dfc520a26515d3"
[[package]]
name = "addr2line"
version = "0.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a76fd60b23679b7d19bd066031410fb7e458ccc5e958eb5c325888ce4baedc97"
dependencies = [
"gimli",
]
[[package]] [[package]]
name = "adler" name = "adler"
version = "1.0.2" version = "1.0.2"
@ -145,7 +154,7 @@ version = "0.37.2+1.3.238"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "28bf19c1f0a470be5fbf7522a308a05df06610252c5bcf5143e1b23f629a9a03" checksum = "28bf19c1f0a470be5fbf7522a308a05df06610252c5bcf5143e1b23f629a9a03"
dependencies = [ dependencies = [
"libloading", "libloading 0.7.4",
] ]
[[package]] [[package]]
@ -166,7 +175,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2c3d816ce6f0e2909a96830d6911c2aff044370b1ef92d7f267b43bae5addedd" checksum = "2c3d816ce6f0e2909a96830d6911c2aff044370b1ef92d7f267b43bae5addedd"
dependencies = [ dependencies = [
"atk-sys", "atk-sys",
"bitflags", "bitflags 1.3.2",
"glib", "glib",
"libc", "libc",
] ]
@ -238,7 +247,7 @@ checksum = "b70caf9f1b0c045f7da350636435b775a9733adf2df56e8aa2a29210fbc335d4"
dependencies = [ dependencies = [
"async-trait", "async-trait",
"axum-core", "axum-core",
"bitflags", "bitflags 1.3.2",
"bytes", "bytes",
"futures-util", "futures-util",
"http", "http",
@ -279,6 +288,21 @@ dependencies = [
"tower-service", "tower-service",
] ]
[[package]]
name = "backtrace"
version = "0.3.67"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "233d376d6d185f2a3093e58f283f60f880315b6c60075b01f36b3b85154564ca"
dependencies = [
"addr2line",
"cc",
"cfg-if",
"libc",
"miniz_oxide 0.6.2",
"object",
"rustc-demangle",
]
[[package]] [[package]]
name = "base64" name = "base64"
version = "0.13.1" version = "0.13.1"
@ -345,6 +369,12 @@ version = "1.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
[[package]]
name = "bitflags"
version = "2.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6776fc96284a0bb647b615056fc496d1fe1644a7ab01829818a6d91cae888b84"
[[package]] [[package]]
name = "block" name = "block"
version = "0.1.6" version = "0.1.6"
@ -448,7 +478,7 @@ version = "0.15.12"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c76ee391b03d35510d9fa917357c7f1855bd9a6659c95a1b392e33f49b3369bc" checksum = "c76ee391b03d35510d9fa917357c7f1855bd9a6659c95a1b392e33f49b3369bc"
dependencies = [ dependencies = [
"bitflags", "bitflags 1.3.2",
"cairo-sys-rs", "cairo-sys-rs",
"glib", "glib",
"libc", "libc",
@ -523,12 +553,6 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "cfg_aliases"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fd16c4719339c4530435d38e511904438d07cce7950afa3718a84ac36c10e89e"
[[package]] [[package]]
name = "chrono" name = "chrono"
version = "0.4.24" version = "0.4.24"
@ -550,7 +574,7 @@ version = "0.24.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f425db7937052c684daec3bd6375c8abe2d146dca4b8b143d6db777c39138f3a" checksum = "f425db7937052c684daec3bd6375c8abe2d146dca4b8b143d6db777c39138f3a"
dependencies = [ dependencies = [
"bitflags", "bitflags 1.3.2",
"block", "block",
"cocoa-foundation", "cocoa-foundation",
"core-foundation", "core-foundation",
@ -566,7 +590,7 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "931d3837c286f56e3c58423ce4eba12d08db2374461a785c86f672b08b5650d6" checksum = "931d3837c286f56e3c58423ce4eba12d08db2374461a785c86f672b08b5650d6"
dependencies = [ dependencies = [
"bitflags", "bitflags 1.3.2",
"block", "block",
"core-foundation", "core-foundation",
"core-graphics-types", "core-graphics-types",
@ -602,6 +626,12 @@ dependencies = [
"winapi", "winapi",
] ]
[[package]]
name = "com-rs"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bf43edc576402991846b093a7ca18a3477e0ef9c588cde84964b5d3e43016642"
[[package]] [[package]]
name = "combine" name = "combine"
version = "4.6.6" version = "4.6.6"
@ -639,6 +669,7 @@ dependencies = [
"serde_json", "serde_json",
"tempfile", "tempfile",
"tokio", "tokio",
"tower-http",
] ]
[[package]] [[package]]
@ -679,7 +710,7 @@ version = "0.22.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2581bbab3b8ffc6fcbd550bf46c355135d16e9ff2a6ea032ad6b9bf1d7efe4fb" checksum = "2581bbab3b8ffc6fcbd550bf46c355135d16e9ff2a6ea032ad6b9bf1d7efe4fb"
dependencies = [ dependencies = [
"bitflags", "bitflags 1.3.2",
"core-foundation", "core-foundation",
"core-graphics-types", "core-graphics-types",
"foreign-types", "foreign-types",
@ -692,7 +723,7 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3a68b68b3446082644c91ac778bf50cd4104bfb002b5a6a7c44cca5a2c70788b" checksum = "3a68b68b3446082644c91ac778bf50cd4104bfb002b5a6a7c44cca5a2c70788b"
dependencies = [ dependencies = [
"bitflags", "bitflags 1.3.2",
"core-foundation", "core-foundation",
"foreign-types", "foreign-types",
"libc", "libc",
@ -844,12 +875,12 @@ dependencies = [
[[package]] [[package]]
name = "d3d12" name = "d3d12"
version = "0.5.0" version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "827914e1f53b1e0e025ecd3d967a7836b7bcb54520f90e21ef8df7b4d88a2759" checksum = "d8f0de2f5a8e7bd4a9eec0e3c781992a4ce1724f68aec7d7a3715344de8b39da"
dependencies = [ dependencies = [
"bitflags", "bitflags 1.3.2",
"libloading", "libloading 0.7.4",
"winapi", "winapi",
] ]
@ -1127,7 +1158,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3b9429470923de8e8cbd4d2dc513535400b4b3fef0319fb5c4e1f520a7bef743" checksum = "3b9429470923de8e8cbd4d2dc513535400b4b3fef0319fb5c4e1f520a7bef743"
dependencies = [ dependencies = [
"crc32fast", "crc32fast",
"miniz_oxide", "miniz_oxide 0.7.1",
] ]
[[package]] [[package]]
@ -1294,7 +1325,7 @@ version = "0.15.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a6e05c1f572ab0e1f15be94217f0dc29088c248b14f792a5ff0af0d84bcda9e8" checksum = "a6e05c1f572ab0e1f15be94217f0dc29088c248b14f792a5ff0af0d84bcda9e8"
dependencies = [ dependencies = [
"bitflags", "bitflags 1.3.2",
"cairo-rs", "cairo-rs",
"gdk-pixbuf", "gdk-pixbuf",
"gdk-sys", "gdk-sys",
@ -1310,7 +1341,7 @@ version = "0.15.11"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ad38dd9cc8b099cceecdf41375bb6d481b1b5a7cd5cd603e10a69a9383f8619a" checksum = "ad38dd9cc8b099cceecdf41375bb6d481b1b5a7cd5cd603e10a69a9383f8619a"
dependencies = [ dependencies = [
"bitflags", "bitflags 1.3.2",
"gdk-pixbuf-sys", "gdk-pixbuf-sys",
"gio", "gio",
"glib", "glib",
@ -1405,13 +1436,19 @@ dependencies = [
"wasi 0.11.0+wasi-snapshot-preview1", "wasi 0.11.0+wasi-snapshot-preview1",
] ]
[[package]]
name = "gimli"
version = "0.27.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ad0a93d233ebf96623465aad4046a8d3aa4da22d4f4beba5388838c8a434bbb4"
[[package]] [[package]]
name = "gio" name = "gio"
version = "0.15.12" version = "0.15.12"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "68fdbc90312d462781a395f7a16d96a2b379bb6ef8cd6310a2df272771c4283b" checksum = "68fdbc90312d462781a395f7a16d96a2b379bb6ef8cd6310a2df272771c4283b"
dependencies = [ dependencies = [
"bitflags", "bitflags 1.3.2",
"futures-channel", "futures-channel",
"futures-core", "futures-core",
"futures-io", "futures-io",
@ -1452,7 +1489,7 @@ version = "0.15.12"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "edb0306fbad0ab5428b0ca674a23893db909a98582969c9b537be4ced78c505d" checksum = "edb0306fbad0ab5428b0ca674a23893db909a98582969c9b537be4ced78c505d"
dependencies = [ dependencies = [
"bitflags", "bitflags 1.3.2",
"futures-channel", "futures-channel",
"futures-core", "futures-core",
"futures-executor", "futures-executor",
@ -1512,9 +1549,9 @@ dependencies = [
[[package]] [[package]]
name = "glow" name = "glow"
version = "0.11.2" version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d8bd5877156a19b8ac83a29b2306fe20537429d318f3ff0a1a2119f8d9c61919" checksum = "4e007a07a24de5ecae94160f141029e9a347282cfe25d1d58d85d845cf3130f1"
dependencies = [ dependencies = [
"js-sys", "js-sys",
"slotmap", "slotmap",
@ -1539,7 +1576,7 @@ version = "0.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7fc59e5f710e310e76e6707f86c561dd646f69a8876da9131703b2f717de818d" checksum = "7fc59e5f710e310e76e6707f86c561dd646f69a8876da9131703b2f717de818d"
dependencies = [ dependencies = [
"bitflags", "bitflags 1.3.2",
"gpu-alloc-types", "gpu-alloc-types",
] ]
@ -1549,7 +1586,20 @@ version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "54804d0d6bc9d7f26db4eaec1ad10def69b599315f487d32c334a80d1efe67a5" checksum = "54804d0d6bc9d7f26db4eaec1ad10def69b599315f487d32c334a80d1efe67a5"
dependencies = [ dependencies = [
"bitflags", "bitflags 1.3.2",
]
[[package]]
name = "gpu-allocator"
version = "0.22.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ce95f9e2e11c2c6fadfce42b5af60005db06576f231f5c92550fdded43c423e8"
dependencies = [
"backtrace",
"log",
"thiserror",
"winapi",
"windows 0.44.0",
] ]
[[package]] [[package]]
@ -1570,7 +1620,7 @@ version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b0c02e1ba0bdb14e965058ca34e09c020f8e507a760df1121728e0aef68d57a" checksum = "0b0c02e1ba0bdb14e965058ca34e09c020f8e507a760df1121728e0aef68d57a"
dependencies = [ dependencies = [
"bitflags", "bitflags 1.3.2",
"gpu-descriptor-types", "gpu-descriptor-types",
"hashbrown", "hashbrown",
] ]
@ -1581,7 +1631,7 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "363e3677e55ad168fef68cf9de3a4a310b53124c5e784c53a1d70e92d23f2126" checksum = "363e3677e55ad168fef68cf9de3a4a310b53124c5e784c53a1d70e92d23f2126"
dependencies = [ dependencies = [
"bitflags", "bitflags 1.3.2",
] ]
[[package]] [[package]]
@ -1723,7 +1773,7 @@ name = "graphite-editor"
version = "0.0.0" version = "0.0.0"
dependencies = [ dependencies = [
"bezier-rs", "bezier-rs",
"bitflags", "bitflags 1.3.2",
"borrow_stack", "borrow_stack",
"derivative", "derivative",
"dyn-any", "dyn-any",
@ -1785,7 +1835,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "92e3004a2d5d6d8b5057d2b57b3712c9529b62e82c77f25c1fecde1fd5c23bd0" checksum = "92e3004a2d5d6d8b5057d2b57b3712c9529b62e82c77f25c1fecde1fd5c23bd0"
dependencies = [ dependencies = [
"atk", "atk",
"bitflags", "bitflags 1.3.2",
"cairo-rs", "cairo-rs",
"field-offset", "field-offset",
"futures-channel", "futures-channel",
@ -1870,6 +1920,21 @@ dependencies = [
"ahash 0.7.6", "ahash 0.7.6",
] ]
[[package]]
name = "hassle-rs"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1397650ee315e8891a0df210707f0fc61771b0cc518c3023896064c5407cb3b0"
dependencies = [
"bitflags 1.3.2",
"com-rs",
"libc",
"libloading 0.7.4",
"thiserror",
"widestring",
"winapi",
]
[[package]] [[package]]
name = "heck" name = "heck"
version = "0.3.3" version = "0.3.3"
@ -1957,6 +2022,12 @@ version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "21dec9db110f5f872ed9699c3ecf50cf16f423502706ba5c72462e28d3157573" checksum = "21dec9db110f5f872ed9699c3ecf50cf16f423502706ba5c72462e28d3157573"
[[package]]
name = "http-range-header"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0bfe8eed0a9285ef776bb792479ea3834e8b94e13d615c2f66d03dd50a435a29"
[[package]] [[package]]
name = "httparse" name = "httparse"
version = "1.8.0" version = "1.8.0"
@ -2189,7 +2260,7 @@ version = "0.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bf053e7843f2812ff03ef5afe34bb9c06ffee120385caad4f6b9967fcd37d41c" checksum = "bf053e7843f2812ff03ef5afe34bb9c06ffee120385caad4f6b9967fcd37d41c"
dependencies = [ dependencies = [
"bitflags", "bitflags 1.3.2",
"glib", "glib",
"javascriptcore-rs-sys", "javascriptcore-rs-sys",
] ]
@ -2253,7 +2324,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8c2352bd1d0bceb871cb9d40f24360c8133c11d7486b68b5381c1dd1a32015e3" checksum = "8c2352bd1d0bceb871cb9d40f24360c8133c11d7486b68b5381c1dd1a32015e3"
dependencies = [ dependencies = [
"libc", "libc",
"libloading", "libloading 0.7.4",
"pkg-config", "pkg-config",
] ]
@ -2321,6 +2392,16 @@ dependencies = [
"winapi", "winapi",
] ]
[[package]]
name = "libloading"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d580318f95776505201b28cf98eb1fa5e4be3b689633ba6a3e6cd880ff22d8cb"
dependencies = [
"cfg-if",
"windows-sys 0.48.0",
]
[[package]] [[package]]
name = "libm" name = "libm"
version = "0.2.6" version = "0.2.6"
@ -2484,7 +2565,7 @@ version = "0.24.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "de11355d1f6781482d027a3b4d4de7825dcedb197bf573e0596d00008402d060" checksum = "de11355d1f6781482d027a3b4d4de7825dcedb197bf573e0596d00008402d060"
dependencies = [ dependencies = [
"bitflags", "bitflags 1.3.2",
"block", "block",
"core-graphics-types", "core-graphics-types",
"foreign-types", "foreign-types",
@ -2498,6 +2579,15 @@ version = "0.3.17"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a"
[[package]]
name = "miniz_oxide"
version = "0.6.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b275950c28b37e794e8c55d88aeb5e139d0ce23fdbbeda68f8d7174abdf9e8fa"
dependencies = [
"adler",
]
[[package]] [[package]]
name = "miniz_oxide" name = "miniz_oxide"
version = "0.7.1" version = "0.7.1"
@ -2522,12 +2612,12 @@ dependencies = [
[[package]] [[package]]
name = "naga" name = "naga"
version = "0.10.0" version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "262d2840e72dbe250e8cf2f522d080988dfca624c4112c096238a4845f591707" checksum = "94d3edd593521f4a1dfd9b25193ed0224764572905f013d30ca5fbb85e010876"
dependencies = [ dependencies = [
"bit-set", "bit-set",
"bitflags", "bitflags 1.3.2",
"codespan-reporting", "codespan-reporting",
"hexf-parse", "hexf-parse",
"indexmap", "indexmap",
@ -2609,7 +2699,7 @@ version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2032c77e030ddee34a6787a64166008da93f6a352b629261d0fee232b8742dd4" checksum = "2032c77e030ddee34a6787a64166008da93f6a352b629261d0fee232b8742dd4"
dependencies = [ dependencies = [
"bitflags", "bitflags 1.3.2",
"jni-sys", "jni-sys",
"ndk-sys", "ndk-sys",
"num_enum", "num_enum",
@ -2842,6 +2932,15 @@ dependencies = [
"objc", "objc",
] ]
[[package]]
name = "object"
version = "0.30.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ea86265d3d3dcb6a27fc51bd29a4bf387fae9d2986b823079d4986af253eb439"
dependencies = [
"memchr",
]
[[package]] [[package]]
name = "once_cell" name = "once_cell"
version = "1.17.1" version = "1.17.1"
@ -2864,7 +2963,7 @@ version = "0.10.52"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "01b8574602df80f7b85fdfc5392fa884a4e3b3f4f35402c070ab34c3d3f78d56" checksum = "01b8574602df80f7b85fdfc5392fa884a4e3b3f4f35402c070ab34c3d3f78d56"
dependencies = [ dependencies = [
"bitflags", "bitflags 1.3.2",
"cfg-if", "cfg-if",
"foreign-types", "foreign-types",
"libc", "libc",
@ -2935,7 +3034,7 @@ version = "0.15.10"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "22e4045548659aee5313bde6c582b0d83a627b7904dd20dc2d9ef0895d414e4f" checksum = "22e4045548659aee5313bde6c582b0d83a627b7904dd20dc2d9ef0895d414e4f"
dependencies = [ dependencies = [
"bitflags", "bitflags 1.3.2",
"glib", "glib",
"libc", "libc",
"once_cell", "once_cell",
@ -3161,11 +3260,11 @@ version = "0.17.8"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aaeebc51f9e7d2c150d3f3bfeb667f2aa985db5ef1e3d212847bdedb488beeaa" checksum = "aaeebc51f9e7d2c150d3f3bfeb667f2aa985db5ef1e3d212847bdedb488beeaa"
dependencies = [ dependencies = [
"bitflags", "bitflags 1.3.2",
"crc32fast", "crc32fast",
"fdeflate", "fdeflate",
"flate2", "flate2",
"miniz_oxide", "miniz_oxide 0.7.1",
] ]
[[package]] [[package]]
@ -3379,7 +3478,7 @@ version = "0.2.16"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fb5a58c1855b4b6819d59012155603f0b22ad30cad752600aadfcb695265519a" checksum = "fb5a58c1855b4b6819d59012155603f0b22ad30cad752600aadfcb695265519a"
dependencies = [ dependencies = [
"bitflags", "bitflags 1.3.2",
] ]
[[package]] [[package]]
@ -3388,7 +3487,7 @@ version = "0.3.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "567664f262709473930a4bf9e51bf2ebf3348f2e748ccc50dea20646858f8f29" checksum = "567664f262709473930a4bf9e51bf2ebf3348f2e748ccc50dea20646858f8f29"
dependencies = [ dependencies = [
"bitflags", "bitflags 1.3.2",
] ]
[[package]] [[package]]
@ -3447,9 +3546,9 @@ dependencies = [
[[package]] [[package]]
name = "renderdoc-sys" name = "renderdoc-sys"
version = "0.7.1" version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f1382d1f0a252c4bf97dc20d979a2fdd05b024acd7c2ed0f7595d7817666a157" checksum = "216080ab382b992234dda86873c18d4c48358f5cfcb70fd693d7f6f2131b628b"
[[package]] [[package]]
name = "reqwest" name = "reqwest"
@ -3539,10 +3638,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "300a51053b1cb55c80b7a9fde4120726ddf25ca241a1cbb926626f62fb136bff" checksum = "300a51053b1cb55c80b7a9fde4120726ddf25ca241a1cbb926626f62fb136bff"
dependencies = [ dependencies = [
"base64 0.13.1", "base64 0.13.1",
"bitflags", "bitflags 1.3.2",
"serde", "serde",
] ]
[[package]]
name = "rustc-demangle"
version = "0.1.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76"
[[package]] [[package]]
name = "rustc-hash" name = "rustc-hash"
version = "1.1.0" version = "1.1.0"
@ -3564,7 +3669,7 @@ version = "0.37.15"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a0661814f891c57c930a610266415528da53c4933e6dea5fb350cbfe048a9ece" checksum = "a0661814f891c57c930a610266415528da53c4933e6dea5fb350cbfe048a9ece"
dependencies = [ dependencies = [
"bitflags", "bitflags 1.3.2",
"errno", "errno",
"io-lifetimes", "io-lifetimes",
"libc", "libc",
@ -3605,7 +3710,7 @@ version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ab9e34ecf6900625412355a61bda0bd68099fe674de707c67e5e4aed2c05e489" checksum = "ab9e34ecf6900625412355a61bda0bd68099fe674de707c67e5e4aed2c05e489"
dependencies = [ dependencies = [
"bitflags", "bitflags 1.3.2",
"bytemuck", "bytemuck",
"smallvec", "smallvec",
"ttf-parser", "ttf-parser",
@ -3688,7 +3793,7 @@ version = "2.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a332be01508d814fed64bf28f798a146d73792121129962fdf335bb3c49a4254" checksum = "a332be01508d814fed64bf28f798a146d73792121129962fdf335bb3c49a4254"
dependencies = [ dependencies = [
"bitflags", "bitflags 1.3.2",
"core-foundation", "core-foundation",
"core-foundation-sys", "core-foundation-sys",
"libc", "libc",
@ -3711,7 +3816,7 @@ version = "0.22.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df320f1889ac4ba6bc0cdc9c9af7af4bd64bb927bccdf32d81140dc1f9be12fe" checksum = "df320f1889ac4ba6bc0cdc9c9af7af4bd64bb927bccdf32d81140dc1f9be12fe"
dependencies = [ dependencies = [
"bitflags", "bitflags 1.3.2",
"cssparser", "cssparser",
"derive_more", "derive_more",
"fxhash", "fxhash",
@ -4018,7 +4123,7 @@ version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b2b4d76501d8ba387cf0fefbe055c3e0a59891d09f0f995ae4e4b16f6b60f3c0" checksum = "b2b4d76501d8ba387cf0fefbe055c3e0a59891d09f0f995ae4e4b16f6b60f3c0"
dependencies = [ dependencies = [
"bitflags", "bitflags 1.3.2",
"gio", "gio",
"glib", "glib",
"libc", "libc",
@ -4032,7 +4137,7 @@ version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "009ef427103fcb17f802871647a7fa6c60cbb654b4c4e4c0ac60a31c5f6dc9cf" checksum = "009ef427103fcb17f802871647a7fa6c60cbb654b4c4e4c0ac60a31c5f6dc9cf"
dependencies = [ dependencies = [
"bitflags", "bitflags 1.3.2",
"gio-sys", "gio-sys",
"glib-sys", "glib-sys",
"gobject-sys", "gobject-sys",
@ -4087,7 +4192,7 @@ version = "0.2.0+1.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "246bfa38fe3db3f1dfc8ca5a2cdeb7348c78be2112740cc0ec8ef18b6d94f830" checksum = "246bfa38fe3db3f1dfc8ca5a2cdeb7348c78be2112740cc0ec8ef18b6d94f830"
dependencies = [ dependencies = [
"bitflags", "bitflags 1.3.2",
"num-traits", "num-traits",
] ]
@ -4097,7 +4202,7 @@ version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3197bd4c021c2dfc0f9dfb356312c8f7842d972d5545c308ad86422c2e2d3e66" checksum = "3197bd4c021c2dfc0f9dfb356312c8f7842d972d5545c308ad86422c2e2d3e66"
dependencies = [ dependencies = [
"bitflags", "bitflags 1.3.2",
"glam", "glam",
"num-traits", "num-traits",
"spirv-std-macros", "spirv-std-macros",
@ -4262,7 +4367,7 @@ version = "0.15.8"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac8e6399427c8494f9849b58694754d7cc741293348a6836b6c8d2c5aa82d8e6" checksum = "ac8e6399427c8494f9849b58694754d7cc741293348a6836b6c8d2c5aa82d8e6"
dependencies = [ dependencies = [
"bitflags", "bitflags 1.3.2",
"cairo-rs", "cairo-rs",
"cc", "cc",
"cocoa", "cocoa",
@ -4788,6 +4893,24 @@ dependencies = [
"tracing", "tracing",
] ]
[[package]]
name = "tower-http"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5d1d42a9b3f3ec46ba828e8d376aec14592ea199f70a06a548587ecd1c4ab658"
dependencies = [
"bitflags 1.3.2",
"bytes",
"futures-core",
"futures-util",
"http",
"http-body",
"http-range-header",
"pin-project-lite",
"tower-layer",
"tower-service",
]
[[package]] [[package]]
name = "tower-layer" name = "tower-layer"
version = "0.3.2" version = "0.3.2"
@ -5073,7 +5196,7 @@ dependencies = [
"heck 0.4.1", "heck 0.4.1",
"indexmap", "indexmap",
"lazy_static", "lazy_static",
"libloading", "libloading 0.7.4",
"objc", "objc",
"parking_lot", "parking_lot",
"proc-macro2", "proc-macro2",
@ -5242,7 +5365,7 @@ version = "0.18.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b8f859735e4a452aeb28c6c56a852967a8a76c8eb1cc32dbf931ad28a13d6370" checksum = "b8f859735e4a452aeb28c6c56a852967a8a76c8eb1cc32dbf931ad28a13d6370"
dependencies = [ dependencies = [
"bitflags", "bitflags 1.3.2",
"cairo-rs", "cairo-rs",
"gdk", "gdk",
"gdk-sys", "gdk-sys",
@ -5267,7 +5390,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4d76ca6ecc47aeba01ec61e480139dda143796abcae6f83bcddf50d6b5b1dcf3" checksum = "4d76ca6ecc47aeba01ec61e480139dda143796abcae6f83bcddf50d6b5b1dcf3"
dependencies = [ dependencies = [
"atk-sys", "atk-sys",
"bitflags", "bitflags 1.3.2",
"cairo-sys-rs", "cairo-sys-rs",
"gdk-pixbuf-sys", "gdk-pixbuf-sys",
"gdk-sys", "gdk-sys",
@ -5342,15 +5465,17 @@ dependencies = [
[[package]] [[package]]
name = "wgpu" name = "wgpu"
version = "0.14.2" version = "0.16.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "81f643110d228fd62a60c5ed2ab56c4d5b3704520bd50561174ec4ec74932937" checksum = "3059ea4ddec41ca14f356833e2af65e7e38c0a8f91273867ed526fb9bafcca95"
dependencies = [ dependencies = [
"arrayvec", "arrayvec",
"cfg-if",
"js-sys", "js-sys",
"log", "log",
"naga", "naga",
"parking_lot", "parking_lot",
"profiling",
"raw-window-handle", "raw-window-handle",
"smallvec", "smallvec",
"static_assertions", "static_assertions",
@ -5364,21 +5489,20 @@ dependencies = [
[[package]] [[package]]
name = "wgpu-core" name = "wgpu-core"
version = "0.14.2" version = "0.16.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6000d1284ef8eec6076fd5544a73125fd7eb9b635f18dceeb829d826f41724ca" checksum = "8f478237b4bf0d5b70a39898a66fa67ca3a007d79f2520485b8b0c3dfc46f8c2"
dependencies = [ dependencies = [
"arrayvec", "arrayvec",
"bit-vec", "bit-vec",
"bitflags", "bitflags 2.3.1",
"cfg_aliases",
"codespan-reporting", "codespan-reporting",
"fxhash",
"log", "log",
"naga", "naga",
"parking_lot", "parking_lot",
"profiling", "profiling",
"raw-window-handle", "raw-window-handle",
"rustc-hash",
"smallvec", "smallvec",
"thiserror", "thiserror",
"web-sys", "web-sys",
@ -5410,26 +5534,28 @@ dependencies = [
[[package]] [[package]]
name = "wgpu-hal" name = "wgpu-hal"
version = "0.14.1" version = "0.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3cc320a61acb26be4f549c9b1b53405c10a223fbfea363ec39474c32c348d12f" checksum = "41af2ea7d87bd41ad0a37146252d5f7c26490209f47f544b2ee3b3ff34c7732e"
dependencies = [ dependencies = [
"android_system_properties", "android_system_properties",
"arrayvec", "arrayvec",
"ash", "ash",
"bit-set", "bit-set",
"bitflags", "bitflags 2.3.1",
"block", "block",
"core-graphics-types", "core-graphics-types",
"d3d12", "d3d12",
"foreign-types", "foreign-types",
"fxhash",
"glow", "glow",
"gpu-alloc", "gpu-alloc",
"gpu-allocator",
"gpu-descriptor", "gpu-descriptor",
"hassle-rs",
"js-sys", "js-sys",
"khronos-egl", "khronos-egl",
"libloading", "libc",
"libloading 0.8.0",
"log", "log",
"metal", "metal",
"naga", "naga",
@ -5439,6 +5565,7 @@ dependencies = [
"range-alloc", "range-alloc",
"raw-window-handle", "raw-window-handle",
"renderdoc-sys", "renderdoc-sys",
"rustc-hash",
"smallvec", "smallvec",
"thiserror", "thiserror",
"wasm-bindgen", "wasm-bindgen",
@ -5449,11 +5576,13 @@ dependencies = [
[[package]] [[package]]
name = "wgpu-types" name = "wgpu-types"
version = "0.14.1" version = "0.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fb6b28ef22cac17b9109b25b3bf8c9a103eeb293d7c5f78653979b09140375f6" checksum = "5bd33a976130f03dcdcd39b3810c0c3fc05daf86f0aaf867db14bfb7c4a9a32b"
dependencies = [ dependencies = [
"bitflags", "bitflags 2.3.1",
"js-sys",
"web-sys",
] ]
[[package]] [[package]]
@ -5466,6 +5595,12 @@ dependencies = [
"safe_arch", "safe_arch",
] ]
[[package]]
name = "widestring"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "653f141f39ec16bba3c5abe400a0c60da7468261cc2cbf36805022876bc721a8"
[[package]] [[package]]
name = "winapi" name = "winapi"
version = "0.3.9" version = "0.3.9"
@ -5524,6 +5659,15 @@ dependencies = [
"windows_x86_64_msvc 0.39.0", "windows_x86_64_msvc 0.39.0",
] ]
[[package]]
name = "windows"
version = "0.44.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9e745dab35a0c4c77aa3ce42d595e13d2003d6902d6b08c9ef5fc326d08da12b"
dependencies = [
"windows-targets 0.42.2",
]
[[package]] [[package]]
name = "windows" name = "windows"
version = "0.48.0" version = "0.48.0"

View file

@ -69,6 +69,7 @@ pub struct NodePropertiesContext<'a> {
pub enum NodeImplementation { pub enum NodeImplementation {
ProtoNode(NodeIdentifier), ProtoNode(NodeIdentifier),
DocumentNode(NodeNetwork), DocumentNode(NodeNetwork),
Extract,
} }
impl NodeImplementation { impl NodeImplementation {
@ -718,14 +719,27 @@ fn static_nodes() -> Vec<DocumentNodeType> {
inputs: vec![ inputs: vec![
DocumentInputType::value("Image", TaggedValue::ImageFrame(ImageFrame::empty()), true), DocumentInputType::value("Image", TaggedValue::ImageFrame(ImageFrame::empty()), true),
DocumentInputType { DocumentInputType {
name: "Path", name: "Node",
data_type: FrontendGraphDataType::Text, data_type: FrontendGraphDataType::General,
default: NodeInput::value(TaggedValue::String(String::new()), false), default: NodeInput::value(TaggedValue::DocumentNode(DocumentNode::default()), true),
}, },
], ],
outputs: vec![DocumentOutputType::new("Image", FrontendGraphDataType::Raster)], outputs: vec![DocumentOutputType::new("Image", FrontendGraphDataType::Raster)],
properties: node_properties::gpu_map_properties, properties: node_properties::no_properties,
}, },
DocumentNodeType {
name: "Extract",
category: "Macros",
identifier: NodeImplementation::Extract,
inputs: vec![DocumentInputType {
name: "Node",
data_type: FrontendGraphDataType::General,
default: NodeInput::value(TaggedValue::DocumentNode(DocumentNode::default()), true),
}],
outputs: vec![DocumentOutputType::new("DocumentNode", FrontendGraphDataType::General)],
properties: node_properties::no_properties,
},
#[cfg(feature = "quantization")]
#[cfg(feature = "quantization")] #[cfg(feature = "quantization")]
DocumentNodeType { DocumentNodeType {
name: "Generate Quantization", name: "Generate Quantization",
@ -1156,6 +1170,7 @@ impl DocumentNodeType {
let num_inputs = self.inputs.len(); let num_inputs = self.inputs.len();
let inner_network = match &self.identifier { let inner_network = match &self.identifier {
NodeImplementation::DocumentNode(network) => network.clone(),
NodeImplementation::ProtoNode(ident) => { NodeImplementation::ProtoNode(ident) => {
NodeNetwork { NodeNetwork {
inputs: (0..num_inputs).map(|_| 0).collect(), inputs: (0..num_inputs).map(|_| 0).collect(),
@ -1175,7 +1190,22 @@ impl DocumentNodeType {
..Default::default() ..Default::default()
} }
} }
NodeImplementation::DocumentNode(network) => network.clone(), NodeImplementation::Extract => NodeNetwork {
inputs: (0..num_inputs).map(|_| 0).collect(),
outputs: vec![NodeOutput::new(0, 0)],
nodes: [(
0,
DocumentNode {
name: "ExtractNode".to_string(),
implementation: DocumentNodeImplementation::Extract,
inputs: self.inputs.iter().map(|i| NodeInput::Network(i.default.ty())).collect(),
..Default::default()
},
)]
.into_iter()
.collect(),
..Default::default()
},
}; };
DocumentNodeImplementation::Network(inner_network) DocumentNodeImplementation::Network(inner_network)

View file

@ -1,5 +0,0 @@
[target.wasm32-unknown-unknown]
rustflags = ["-C", "target-feature=+simd128,+atomics,+bulk-memory,+mutable-globals"]
[unstable]
build-std = ["panic_abort", "std"]

View file

@ -0,0 +1,6 @@
[target.wasm32-unknown-unknown]
#rustflags = ["-C", "target-feature=+simd128,+atomics,+bulk-memory,+mutable-globals","--cfg=web_sys_unstable_apis"]
rustflags = ["-C", "target-feature=+simd128","--cfg=web_sys_unstable_apis"]
[unstable]
build-std = ["panic_abort", "std"]

View file

@ -13,7 +13,7 @@ license = "Apache-2.0"
[features] [features]
tauri = ["ron"] tauri = ["ron"]
gpu = ["editor/gpu"] gpu = ["editor/gpu"]
default = [] default = ["gpu"]
[lib] [lib]
crate-type = ["cdylib", "rlib"] crate-type = ["cdylib", "rlib"]
@ -38,7 +38,7 @@ bezier-rs = { path = "../../libraries/bezier-rs" }
[dependencies.web-sys] [dependencies.web-sys]
version = "0.3.4" version = "0.3.4"
features = ['Window'] features = ["Window"]
[dev-dependencies] [dev-dependencies]
wasm-bindgen-test = "0.3.22" wasm-bindgen-test = "0.3.22"

View file

@ -2,22 +2,22 @@ use gpu_compiler_bin_wrapper::CompileRequest;
use gpu_executor::ShaderIO; use gpu_executor::ShaderIO;
use graph_craft::{proto::ProtoNetwork, Type}; use graph_craft::{proto::ProtoNetwork, Type};
pub async fn compile(network: ProtoNetwork, inputs: Vec<Type>, output: Type, io: ShaderIO) -> Result<Shader, reqwest::Error> { pub async fn compile(networks: Vec<ProtoNetwork>, inputs: Vec<Type>, outputs: Vec<Type>, io: ShaderIO) -> Result<Shader, reqwest::Error> {
let client = reqwest::Client::new(); let client = reqwest::Client::new();
let compile_request = CompileRequest::new(network, inputs.clone(), output.clone(), io.clone()); let compile_request = CompileRequest::new(networks, inputs.clone(), outputs.clone(), io.clone());
let response = client.post("http://localhost:3000/compile/spirv").json(&compile_request).send(); let response = client.post("http://localhost:3000/compile/spirv").json(&compile_request).send();
let response = response.await?; let response = response.await?;
response.bytes().await.map(|b| Shader { response.bytes().await.map(|b| Shader {
spirv_binary: b.windows(4).map(|x| u32::from_le_bytes(x.try_into().unwrap())).collect(), spirv_binary: b.chunks(4).map(|x| u32::from_le_bytes(x.try_into().unwrap())).collect(),
input_types: inputs, input_types: inputs,
output_type: output, output_types: outputs,
io, io,
}) })
} }
pub fn compile_sync(network: ProtoNetwork, inputs: Vec<Type>, output: Type, io: ShaderIO) -> Result<Shader, reqwest::Error> { pub fn compile_sync(networks: Vec<ProtoNetwork>, inputs: Vec<Type>, outputs: Vec<Type>, io: ShaderIO) -> Result<Shader, reqwest::Error> {
future_executor::block_on(compile(network, inputs, output, io)) future_executor::block_on(compile(networks, inputs, outputs, io))
} }
// TODO: should we add the entry point as a field? // TODO: should we add the entry point as a field?
@ -25,6 +25,6 @@ pub fn compile_sync(network: ProtoNetwork, inputs: Vec<Type>, output: Type, io:
pub struct Shader { pub struct Shader {
pub spirv_binary: Vec<u32>, pub spirv_binary: Vec<u32>,
pub input_types: Vec<Type>, pub input_types: Vec<Type>,
pub output_type: Type, pub output_types: Vec<Type>,
pub io: ShaderIO, pub io: ShaderIO,
} }

View file

@ -36,7 +36,7 @@ fn main() {
output: ShaderInput::OutputBuffer((), concrete!(&mut [u32])), output: ShaderInput::OutputBuffer((), concrete!(&mut [u32])),
}; };
let compile_request = CompileRequest::new(proto_network, vec![concrete!(u32)], concrete!(u32), io); let compile_request = CompileRequest::new(vec![proto_network], vec![concrete!(u32)], vec![concrete!(u32)], io);
let response = client let response = client
.post("http://localhost:3000/compile/spirv") .post("http://localhost:3000/compile/spirv")
.timeout(Duration::from_secs(30)) .timeout(Duration::from_secs(30))

View file

@ -16,3 +16,4 @@ serde = { version = "1.0", features = ["derive"] }
tempfile = "3.3.0" tempfile = "3.3.0"
anyhow = "1.0.68" anyhow = "1.0.68"
futures = "0.3" futures = "0.3"
tower-http = { version = "0.4.0", features = ["cors"] }

View file

@ -1,6 +1,7 @@
use std::sync::Arc; use std::sync::Arc;
use gpu_compiler_bin_wrapper::CompileRequest; use gpu_compiler_bin_wrapper::CompileRequest;
use tower_http::cors::CorsLayer;
use axum::{ use axum::{
extract::{Json, State}, extract::{Json, State},
@ -24,6 +25,7 @@ async fn main() {
.route("/", get(|| async { "Hello from compilation server!" })) .route("/", get(|| async { "Hello from compilation server!" }))
.route("/compile", get(|| async { "Supported targets: spirv" })) .route("/compile", get(|| async { "Supported targets: spirv" }))
.route("/compile/spirv", post(post_compile_spirv)) .route("/compile/spirv", post(post_compile_spirv))
.layer(CorsLayer::permissive())
.with_state(shared_state); .with_state(shared_state);
// run it with hyper on localhost:3000 // run it with hyper on localhost:3000

View file

@ -107,6 +107,15 @@ impl<'i, 's: 'i, I: 'i, O: 'i, N: Node<'i, I, Output = O>> Node<'i, I> for &'s N
(**self).eval(input) (**self).eval(input)
} }
} }
#[cfg(feature = "alloc")]
impl<'i, 's: 'i, I: 'i, O: 'i, N: Node<'i, I, Output = O>> Node<'i, I> for Box<N> {
type Output = O;
fn eval(&'i self, input: I) -> Self::Output {
(**self).eval(input)
}
}
impl<'i, I: 'i, O: 'i> Node<'i, I> for &'i dyn for<'a> Node<'a, I, Output = O> { impl<'i, I: 'i, O: 'i> Node<'i, I> for &'i dyn for<'a> Node<'a, I, Output = O> {
type Output = O; type Output = O;

View file

@ -192,6 +192,14 @@ pub trait Sample {
fn sample(&self, pos: DVec2, area: DVec2) -> Option<Self::Pixel>; fn sample(&self, pos: DVec2, area: DVec2) -> Option<Self::Pixel>;
} }
impl<'i, T: Sample> Sample for &'i T {
type Pixel = T::Pixel;
fn sample(&self, pos: DVec2, area: DVec2) -> Option<Self::Pixel> {
(**self).sample(pos, area)
}
}
// TODO: We might rename this to Bitmap at some point // TODO: We might rename this to Bitmap at some point
pub trait Raster { pub trait Raster {
type Pixel: Pixel; type Pixel: Pixel;
@ -200,6 +208,38 @@ pub trait Raster {
fn get_pixel(&self, x: u32, y: u32) -> Option<Self::Pixel>; fn get_pixel(&self, x: u32, y: u32) -> Option<Self::Pixel>;
} }
impl<'i, T: Raster> Raster for &'i T {
type Pixel = T::Pixel;
fn width(&self) -> u32 {
(**self).width()
}
fn height(&self) -> u32 {
(**self).height()
}
fn get_pixel(&self, x: u32, y: u32) -> Option<Self::Pixel> {
(**self).get_pixel(x, y)
}
}
impl<'i, T: Raster> Raster for &'i mut T {
type Pixel = T::Pixel;
fn width(&self) -> u32 {
(**self).width()
}
fn height(&self) -> u32 {
(**self).height()
}
fn get_pixel(&self, x: u32, y: u32) -> Option<Self::Pixel> {
(**self).get_pixel(x, y)
}
}
pub trait RasterMut: Raster { pub trait RasterMut: Raster {
fn get_pixel_mut(&mut self, x: u32, y: u32) -> Option<&mut Self::Pixel>; fn get_pixel_mut(&mut self, x: u32, y: u32) -> Option<&mut Self::Pixel>;
fn set_pixel(&mut self, x: u32, y: u32, pixel: Self::Pixel) { fn set_pixel(&mut self, x: u32, y: u32, pixel: Self::Pixel) {
@ -215,6 +255,12 @@ pub trait RasterMut: Raster {
} }
} }
impl<'i, T: RasterMut + Raster> RasterMut for &'i mut T {
fn get_pixel_mut(&mut self, x: u32, y: u32) -> Option<&mut Self::Pixel> {
(*self).get_pixel_mut(x, y)
}
}
#[derive(Debug, Default)] #[derive(Debug, Default)]
pub struct MapNode<MapFn> { pub struct MapNode<MapFn> {
map_fn: MapFn, map_fn: MapFn,

View file

@ -363,6 +363,19 @@ fn invert_image(color: Color) -> Color {
color.to_linear_srgb() color.to_linear_srgb()
} }
// TODO replace with trait based implementation
impl<'i> Node<'i, &'i Color> for InvertRGBNode {
type Output = Color;
fn eval(&'i self, color: &'i Color) -> Self::Output {
let color = color.to_gamma_srgb();
let color = color.map_rgb(|c| color.a() - c);
color.to_linear_srgb()
}
}
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]
pub struct ThresholdNode<MinLuminance, MaxLuminance, LuminanceCalc> { pub struct ThresholdNode<MinLuminance, MaxLuminance, LuminanceCalc> {
min_luminance: MinLuminance, min_luminance: MinLuminance,

View file

@ -1,34 +1,158 @@
use crate::Node; use crate::Node;
use core::cell::RefMut;
use core::marker::PhantomData; use core::marker::PhantomData;
use core::ops::{DerefMut, Index, IndexMut}; use core::ops::{Deref, DerefMut, Index, IndexMut};
struct SetNode<S, I, Storage, Index> { pub struct SetNode<Storage> {
storage: Storage, storage: Storage,
index: Index,
_s: PhantomData<S>,
_i: PhantomData<I>,
} }
/*
#[node_macro::node_fn(SetNode<_S, _I>)] #[node_macro::node_fn(SetNode)]
fn set_node<T, _S, _I>(value: T, storage: &'input mut _S, index: _I) fn set_node<_T, _I, A: 'input>(input: (_T, _I), mut storage: RefMut<'input, A>)
where where
_S: IndexMut<_I>, A: DerefMut,
_S::Output: DerefMut<Target = T> + Sized, A::Target: IndexMut<_I, Output = _T>,
{ {
*storage.index_mut(index).deref_mut() = value; let (value, index) = input;
*storage.deref_mut().index_mut(index).deref_mut() = value;
}*/
impl<'input, T: 'input, I: 'input, A: 'input + 'input, S0: 'input> Node<'input, (T, I)> for SetNode<S0>
where
A: DerefMut,
A::Target: IndexMut<I, Output = T>,
S0: for<'any_input> Node<'input, (), Output = A>,
{
type Output = ();
#[inline]
fn eval(&'input self, input: (T, I)) -> Self::Output {
let mut storage = self.storage.eval(());
let (value, index) = input;
*storage.deref_mut().index_mut(index).deref_mut() = value;
}
}
impl<'input, S0: 'input> SetNode<S0> {
pub const fn new(storage: S0) -> Self {
Self { storage }
}
} }
struct GetNode<S, Storage> { pub struct ExtractXNode {}
#[node_macro::node_fn(ExtractXNode)]
fn extract_x_node(input: glam::UVec3) -> usize {
input.x as usize
}
pub struct SetOwnedNode<Storage> {
storage: core::cell::RefCell<Storage>,
}
impl<Storage> SetOwnedNode<Storage> {
pub fn new(storage: Storage) -> Self {
Self {
storage: core::cell::RefCell::new(storage),
}
}
}
impl<'input, I: 'input, T: 'input, Storage, A: ?Sized> Node<'input, (T, I)> for SetOwnedNode<Storage>
where
Storage: DerefMut<Target = A> + 'input,
A: IndexMut<I, Output = T> + 'input,
{
type Output = ();
fn eval(&'input self, input: (T, I)) -> Self::Output {
let (value, index) = input;
*self.storage.borrow_mut().index_mut(index) = value;
}
}
pub struct GetNode<Storage> {
storage: Storage, storage: Storage,
_s: PhantomData<S>,
} }
#[node_macro::node_fn(GetNode<_S>)] impl<Storage> GetNode<Storage> {
fn get_node<_S, I>(index: I, storage: &'input _S) -> &'input _S::Output pub fn new(storage: Storage) -> Self {
where Self { storage }
_S: Index<I>, }
_S::Output: Sized, }
{
storage.index(index) impl<'input, I: 'input, T: 'input, Storage, SNode, A: ?Sized> Node<'input, I> for GetNode<SNode>
where
SNode: Node<'input, (), Output = Storage>,
Storage: Deref<Target = A> + 'input,
A: Index<I, Output = T> + 'input,
T: Clone,
{
type Output = T;
fn eval(&'input self, index: I) -> Self::Output {
let storage = self.storage.eval(());
storage.deref().index(index).deref().clone()
}
}
#[cfg(test)]
mod test {
use crate::value::{CopiedNode, OnceCellNode, RefCellMutNode, UnsafeMutValueNode, ValueNode};
use crate::Node;
use super::*;
#[test]
fn get_node_array() {
let storage = [1, 2, 3];
let node = GetNode::new(CopiedNode::new(&storage));
assert_eq!((&node as &dyn Node<'_, usize, Output = i32>).eval(1), 2);
}
#[test]
fn get_node_vec() {
let storage = vec![1, 2, 3];
let node = GetNode::new(CopiedNode::new(&storage));
assert_eq!(node.eval(1), 2);
}
#[test]
fn get_node_slice() {
let storage: &[i32] = &[1, 2, 3];
let node = GetNode::new(CopiedNode::new(storage));
let _ = &node as &dyn Node<'_, usize, Output = i32>;
assert_eq!(node.eval(1), 2);
}
#[test]
fn set_node_slice() {
let mut backing_storage = [1, 2, 3];
let storage: &mut [i32] = &mut backing_storage;
let storage_node = OnceCellNode::new(storage);
let node = SetNode::new(storage_node);
node.eval((4, 1));
assert_eq!(backing_storage, [1, 4, 3]);
}
#[test]
fn set_owned_node_array() {
let mut storage = [1, 2, 3];
let node = SetOwnedNode::new(&mut storage);
node.eval((4, 1));
assert_eq!(storage, [1, 4, 3]);
}
#[test]
fn set_owned_node_vec() {
let mut storage = vec![1, 2, 3];
let node = SetOwnedNode::new(&mut storage);
node.eval((4, 1));
assert_eq!(storage, [1, 4, 3]);
}
#[test]
fn set_owned_node_slice() {
let mut backing_storage = [1, 2, 3];
let storage: &mut [i32] = &mut backing_storage;
let node = SetOwnedNode::new(storage);
let node = &node as &dyn Node<'_, (i32, usize), Output = ()>;
node.eval((4, 1));
assert_eq!(backing_storage, [1, 4, 3]);
}
} }

View file

@ -39,6 +39,7 @@ pub struct AsyncComposeNode<First, Second, I> {
phantom: PhantomData<I>, phantom: PhantomData<I>,
} }
#[cfg(feature = "alloc")]
impl<'i, 'f: 'i, 's: 'i, Input: 'static, First, Second> Node<'i, Input> for AsyncComposeNode<First, Second, Input> impl<'i, 'f: 'i, 's: 'i, Input: 'static, First, Second> Node<'i, Input> for AsyncComposeNode<First, Second, Input>
where where
First: Node<'i, Input>, First: Node<'i, Input>,
@ -54,6 +55,7 @@ where
} }
} }
#[cfg(feature = "alloc")]
impl<'i, First, Second, Input: 'i> AsyncComposeNode<First, Second, Input> impl<'i, First, Second, Input: 'i> AsyncComposeNode<First, Second, Input>
where where
First: Node<'i, Input>, First: Node<'i, Input>,
@ -77,6 +79,7 @@ pub trait Then<'i, Input: 'i>: Sized {
impl<'i, First: Node<'i, Input>, Input: 'i> Then<'i, Input> for First {} impl<'i, First: Node<'i, Input>, Input: 'i> Then<'i, Input> for First {}
#[cfg(feature = "alloc")]
pub trait AndThen<'i, Input: 'i>: Sized { pub trait AndThen<'i, Input: 'i>: Sized {
fn and_then<Second>(self, second: Second) -> AsyncComposeNode<Self, Second, Input> fn and_then<Second>(self, second: Second) -> AsyncComposeNode<Self, Second, Input>
where where
@ -88,6 +91,7 @@ pub trait AndThen<'i, Input: 'i>: Sized {
} }
} }
#[cfg(feature = "alloc")]
impl<'i, First: Node<'i, Input>, Input: 'i> AndThen<'i, Input> for First {} impl<'i, First: Node<'i, Input>, Input: 'i> AndThen<'i, Input> for First {}
pub struct ConsNode<I: From<()>, Root>(pub Root, PhantomData<I>); pub struct ConsNode<I: From<()>, Root>(pub Root, PhantomData<I>);
@ -108,6 +112,38 @@ impl<'i, Root: Node<'i, I>, I: 'i + From<()>> ConsNode<I, Root> {
} }
} }
pub struct ApplyNode<O, N> {
pub node: N,
_o: PhantomData<O>,
}
/*
#[node_macro::node_fn(ApplyNode)]
fn apply<In, N>(input: In, node: &'any_input N) -> ()
where
// TODO: try to allows this to return output other than ()
N: for<'any_input> Node<'any_input, In, Output = ()>,
{
node.eval(input)
}
*/
impl<'input, In: 'input, N: 'input, S0: 'input, O: 'input> Node<'input, In> for ApplyNode<O, S0>
where
N: Node<'input, In, Output = O>,
S0: Node<'input, (), Output = &'input N>,
{
type Output = <N as Node<'input, In>>::Output;
#[inline]
fn eval(&'input self, input: In) -> Self::Output {
let node = self.node.eval(());
node.eval(input)
}
}
impl<'input, S0: 'input, O: 'static> ApplyNode<O, S0> {
pub const fn new(node: S0) -> Self {
Self { node, _o: PhantomData }
}
}
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use crate::{ops::IdNode, value::ValueNode}; use crate::{ops::IdNode, value::ValueNode};
@ -134,4 +170,15 @@ mod test {
assert_eq!(compose.eval(()), &5); assert_eq!(compose.eval(()), &5);
} }
#[test]
fn test_apply() {
let mut array = [1, 2, 3];
let slice = &mut array;
let set_node = crate::storage::SetOwnedNode::new(slice);
let apply = ApplyNode::new(ValueNode::new(set_node));
assert_eq!(apply.eval((1, 2)), ());
}
} }

View file

@ -1,6 +1,10 @@
use crate::Node; use crate::Node;
use core::marker::PhantomData; use core::{
borrow::BorrowMut,
cell::{Cell, RefCell, RefMut},
marker::PhantomData,
};
#[derive(Default, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)] #[derive(Default, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
pub struct IntNode<const N: u32>; pub struct IntNode<const N: u32>;
@ -13,7 +17,7 @@ impl<'i, const N: u32> Node<'i, ()> for IntNode<N> {
} }
} }
#[derive(Default, Debug)] #[derive(Default, Debug, Clone, Copy)]
pub struct ValueNode<T>(pub T); pub struct ValueNode<T>(pub T);
impl<'i, T: 'i> Node<'i, ()> for ValueNode<T> { impl<'i, T: 'i> Node<'i, ()> for ValueNode<T> {
@ -35,12 +39,62 @@ impl<T> From<T> for ValueNode<T> {
ValueNode::new(value) ValueNode::new(value)
} }
} }
impl<T: Clone> Clone for ValueNode<T> {
fn clone(&self) -> Self { #[derive(Default, Debug, Clone)]
Self(self.0.clone()) pub struct RefCellMutNode<T>(pub RefCell<T>);
impl<'i, T: 'i> Node<'i, ()> for RefCellMutNode<T> {
type Output = RefMut<'i, T>;
#[inline(always)]
fn eval(&'i self, _input: ()) -> Self::Output {
#[cfg(not(target_arch = "spirv"))]
let a = self.0.borrow_mut();
#[cfg(target_arch = "spirv")]
let a = unsafe { self.0.try_borrow_mut().unwrap_unchecked() };
a
}
}
impl<T> RefCellMutNode<T> {
pub const fn new(value: T) -> RefCellMutNode<T> {
RefCellMutNode(RefCell::new(value))
}
}
/// #Safety: Never use this as it is unsound.
#[derive(Default, Debug)]
pub struct UnsafeMutValueNode<T>(pub T);
/// #Safety: Never use this as it is unsound.
impl<'i, T: 'i> Node<'i, ()> for UnsafeMutValueNode<T> {
type Output = &'i mut T;
#[inline(always)]
fn eval(&'i self, _input: ()) -> Self::Output {
unsafe { &mut *(&self.0 as &T as *const T as *mut T) }
}
}
impl<T> UnsafeMutValueNode<T> {
pub const fn new(value: T) -> UnsafeMutValueNode<T> {
UnsafeMutValueNode(value)
}
}
#[derive(Default)]
pub struct OnceCellNode<T>(pub Cell<T>);
impl<'i, T: Default + 'i> Node<'i, ()> for OnceCellNode<T> {
type Output = T;
#[inline(always)]
fn eval(&'i self, _input: ()) -> Self::Output {
self.0.replace(T::default())
}
}
impl<T> OnceCellNode<T> {
pub const fn new(value: T) -> OnceCellNode<T> {
OnceCellNode(Cell::new(value))
} }
} }
impl<T: Clone + Copy> Copy for ValueNode<T> {}
#[derive(Clone, Copy)] #[derive(Clone, Copy)]
pub struct ClonedNode<T: Clone>(pub T); pub struct ClonedNode<T: Clone>(pub T);
@ -75,6 +129,7 @@ impl<'i, T: Clone + 'i> Node<'i, ()> for DebugClonedNode<T> {
#[inline(always)] #[inline(always)]
fn eval(&'i self, _input: ()) -> Self::Output { fn eval(&'i self, _input: ()) -> Self::Output {
// KEEP THIS `debug!()` - It acts as the output for the debug node itself // KEEP THIS `debug!()` - It acts as the output for the debug node itself
#[cfg(not(target_arch = "spirv"))]
log::debug!("DebugClonedNode::eval"); log::debug!("DebugClonedNode::eval");
self.0.clone() self.0.clone()

View file

@ -603,6 +603,7 @@ dependencies = [
"num-traits", "num-traits",
"once_cell", "once_cell",
"rand_chacha", "rand_chacha",
"rustybuzz",
"serde", "serde",
"specta", "specta",
"spin", "spin",
@ -1153,6 +1154,22 @@ dependencies = [
"serde", "serde",
] ]
[[package]]
name = "rustybuzz"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ab9e34ecf6900625412355a61bda0bd68099fe674de707c67e5e4aed2c05e489"
dependencies = [
"bitflags",
"bytemuck",
"smallvec",
"ttf-parser",
"unicode-bidi-mirroring",
"unicode-ccc",
"unicode-general-category",
"unicode-script",
]
[[package]] [[package]]
name = "ryu" name = "ryu"
version = "1.0.12" version = "1.0.12"
@ -1486,6 +1503,12 @@ dependencies = [
"once_cell", "once_cell",
] ]
[[package]]
name = "ttf-parser"
version = "0.17.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "375812fa44dab6df41c195cd2f7fecb488f6c09fbaafb62807488cefab642bff"
[[package]] [[package]]
name = "typenum" name = "typenum"
version = "1.16.0" version = "1.16.0"
@ -1557,12 +1580,36 @@ dependencies = [
"unic-common", "unic-common",
] ]
[[package]]
name = "unicode-bidi-mirroring"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "56d12260fb92d52f9008be7e4bca09f584780eb2266dc8fecc6a192bec561694"
[[package]]
name = "unicode-ccc"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cc2520efa644f8268dce4dcd3050eaa7fc044fca03961e9998ac7e2e92b77cf1"
[[package]]
name = "unicode-general-category"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2281c8c1d221438e373249e065ca4989c4c36952c211ff21a0ee91c44a3869e7"
[[package]] [[package]]
name = "unicode-ident" name = "unicode-ident"
version = "1.0.6" version = "1.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "84a22b9f218b40614adcb3f4ff08b703773ad44fa9423e4e0d346d5db86e4ebc" checksum = "84a22b9f218b40614adcb3f4ff08b703773ad44fa9423e4e0d346d5db86e4ebc"
[[package]]
name = "unicode-script"
version = "0.5.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7d817255e1bed6dfd4ca47258685d14d2bdcfbc64fdc9e3819bd5848057b8ecc"
[[package]] [[package]]
name = "unicode-width" name = "unicode-width"
version = "0.1.10" version = "0.1.10"

View file

@ -6,7 +6,7 @@ use std::io::Write;
pub fn compile_spirv(request: &CompileRequest, compile_dir: Option<&str>, manifest_path: &str) -> anyhow::Result<Vec<u8>> { pub fn compile_spirv(request: &CompileRequest, compile_dir: Option<&str>, manifest_path: &str) -> anyhow::Result<Vec<u8>> {
let serialized_graph = serde_json::to_string(&gpu_executor::CompileRequest { let serialized_graph = serde_json::to_string(&gpu_executor::CompileRequest {
network: request.network.clone(), networks: request.networks.clone(),
io: request.shader_io.clone(), io: request.shader_io.clone(),
})?; })?;
@ -43,23 +43,23 @@ pub fn compile_spirv(request: &CompileRequest, compile_dir: Option<&str>, manife
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct CompileRequest { pub struct CompileRequest {
network: graph_craft::proto::ProtoNetwork, networks: Vec<graph_craft::proto::ProtoNetwork>,
input_types: Vec<Type>, input_types: Vec<Type>,
output_type: Type, output_types: Vec<Type>,
shader_io: ShaderIO, shader_io: ShaderIO,
} }
impl CompileRequest { impl CompileRequest {
pub fn new(network: ProtoNetwork, input_types: Vec<Type>, output_type: Type, io: ShaderIO) -> Self { pub fn new(networks: Vec<ProtoNetwork>, input_types: Vec<Type>, output_types: Vec<Type>, io: ShaderIO) -> Self {
// TODO: add type checking // TODO: add type checking
// for (input, buffer) in input_types.iter().zip(io.inputs.iter()) { // for (input, buffer) in input_types.iter().zip(io.inputs.iter()) {
// assert_eq!(input, &buffer.ty()); // assert_eq!(input, &buffer.ty());
// } // }
// assert_eq!(output_type, io.output.ty()); // assert_eq!(output_type, io.output.ty());
Self { Self {
network, networks,
input_types, input_types,
output_type, output_types,
shader_io: io, shader_io: io,
} }
} }

View file

@ -26,7 +26,7 @@ impl Metadata {
} }
} }
pub fn create_files(metadata: &Metadata, network: &ProtoNetwork, compile_dir: &Path, io: &ShaderIO) -> anyhow::Result<()> { pub fn create_files(metadata: &Metadata, networks: &[ProtoNetwork], compile_dir: &Path, io: &ShaderIO) -> anyhow::Result<()> {
let src = compile_dir.join("src"); let src = compile_dir.join("src");
let cargo_file = compile_dir.join("Cargo.toml"); let cargo_file = compile_dir.join("Cargo.toml");
let cargo_toml = create_cargo_toml(metadata)?; let cargo_toml = create_cargo_toml(metadata)?;
@ -46,7 +46,7 @@ pub fn create_files(metadata: &Metadata, network: &ProtoNetwork, compile_dir: &P
} }
} }
let lib = src.join("lib.rs"); let lib = src.join("lib.rs");
let shader = serialize_gpu(network, io)?; let shader = serialize_gpu(networks, io)?;
eprintln!("{}", shader); eprintln!("{}", shader);
std::fs::write(lib, shader)?; std::fs::write(lib, shader)?;
Ok(()) Ok(())
@ -67,20 +67,21 @@ fn constant_attribute(constant: &GPUConstant) -> &'static str {
} }
} }
pub fn construct_argument(input: &ShaderInput<()>, position: u32) -> String { pub fn construct_argument(input: &ShaderInput<()>, position: u32, binding_offset: u32) -> String {
match input { let line = match input {
ShaderInput::Constant(constant) => format!("#[spirv({})] i{}: {},", constant_attribute(constant), position, constant.ty()), ShaderInput::Constant(constant) => format!("#[spirv({})] i{}: {}", constant_attribute(constant), position, constant.ty()),
ShaderInput::UniformBuffer(_, ty) => { ShaderInput::UniformBuffer(_, ty) => {
format!("#[spirv(uniform, descriptor_set = 0, binding = {})] i{}: &[{}]", position, position, ty,) format!("#[spirv(uniform, descriptor_set = 0, binding = {})] i{}: &[{}]", position + binding_offset, position, ty,)
} }
ShaderInput::StorageBuffer(_, ty) | ShaderInput::ReadBackBuffer(_, ty) => { ShaderInput::StorageBuffer(_, ty) | ShaderInput::ReadBackBuffer(_, ty) => {
format!("#[spirv(storage_buffer, descriptor_set = 0, binding = {})] i{}: &[{}]", position, position, ty,) format!("#[spirv(storage_buffer, descriptor_set = 0, binding = {})] i{}: &[{}]", position + binding_offset, position, ty,)
} }
ShaderInput::OutputBuffer(_, ty) => { ShaderInput::OutputBuffer(_, ty) => {
format!("#[spirv(storage_buffer, descriptor_set = 0, binding = {})] i{}: &mut[{}]", position, position, ty,) format!("#[spirv(storage_buffer, descriptor_set = 0, binding = {})] o{}: &mut[{}]", position + binding_offset, position, ty,)
} }
ShaderInput::WorkGroupMemory(_, ty) => format!("#[spirv(workgroup_memory] i{}: {}", position, ty,), ShaderInput::WorkGroupMemory(_, ty) => format!("#[spirv(workgroup_memory] i{}: {}", position, ty,),
} };
line.replace("glam::u32::uvec3::UVec3", "spirv_std::glam::UVec3")
} }
struct GpuCompiler { struct GpuCompiler {
@ -88,10 +89,10 @@ struct GpuCompiler {
} }
impl SpirVCompiler for GpuCompiler { impl SpirVCompiler for GpuCompiler {
fn compile(&self, network: ProtoNetwork, io: &ShaderIO) -> anyhow::Result<gpu_executor::Shader> { fn compile(&self, networks: &[ProtoNetwork], io: &ShaderIO) -> anyhow::Result<gpu_executor::Shader> {
let metadata = Metadata::new("project".to_owned(), vec!["test@example.com".to_owned()]); let metadata = Metadata::new("project".to_owned(), vec!["test@example.com".to_owned()]);
create_files(&metadata, &network, &self.compile_dir, io)?; create_files(&metadata, networks, &self.compile_dir, io)?;
let result = compile(&self.compile_dir)?; let result = compile(&self.compile_dir)?;
let bytes = std::fs::read(result.module.unwrap_single())?; let bytes = std::fs::read(result.module.unwrap_single())?;
@ -105,50 +106,80 @@ impl SpirVCompiler for GpuCompiler {
} }
} }
pub fn serialize_gpu(network: &ProtoNetwork, io: &ShaderIO) -> anyhow::Result<String> { pub fn serialize_gpu(networks: &[ProtoNetwork], io: &ShaderIO) -> anyhow::Result<String> {
fn nid(id: &u64) -> String { fn nid(id: &u64) -> String {
format!("n{id}") format!("n{id}")
} }
dbg!(&network);
dbg!(&io); dbg!(&io);
let inputs = io.inputs.iter().enumerate().map(|(i, input)| construct_argument(input, i as u32)).collect::<Vec<_>>(); let mut inputs = io
.inputs
.iter()
.filter(|x| !x.is_output())
.enumerate()
.map(|(i, input)| construct_argument(input, i as u32, 0))
.collect::<Vec<_>>();
let offset = inputs.len() as u32;
inputs.extend(io.inputs.iter().filter(|x| x.is_output()).enumerate().map(|(i, input)| construct_argument(input, i as u32, offset)));
let mut nodes = Vec::new(); let mut nodes = Vec::new();
let mut input_nodes = Vec::new(); let mut input_nodes = Vec::new();
#[derive(serde::Serialize)] let mut output_nodes = Vec::new();
struct Node { for network in networks {
id: String, dbg!(&network);
fqn: String, //assert_eq!(network.inputs.len(), io.inputs.iter().filter(|x| !x.is_output()).count());
args: Vec<String>, #[derive(serde::Serialize, Debug)]
} struct Node {
for id in network.inputs.iter() { id: String,
let Some((_, node)) = network.nodes.iter().find(|(i, _)| i == id) else { index: usize,
fqn: String,
args: Vec<String>,
}
for (i, id) in network.inputs.iter().enumerate() {
let Some((_, node)) = network.nodes.iter().find(|(i, _)| i == id) else {
anyhow::bail!("Input node not found"); anyhow::bail!("Input node not found");
}; };
let fqn = &node.identifier.name; let fqn = &node.identifier.name;
let id = nid(id); let id = nid(id);
input_nodes.push(Node { let node = Node {
id, id: id.clone(),
fqn: fqn.to_string().split("<").next().unwrap().to_owned(), index: i,
args: node.construction_args.new_function_args(), fqn: fqn.to_string().split('<').next().unwrap().to_owned(),
}); args: node.construction_args.new_function_args(),
} };
dbg!(&node);
for (ref id, node) in network.nodes.iter() { if !io.inputs[i].is_output() {
if network.inputs.contains(id) { if input_nodes.iter().any(|x: &Node| x.id == id) {
continue; continue;
}
input_nodes.push(node);
}
} }
let fqn = &node.identifier.name; for (ref id, node) in network.nodes.iter() {
let id = nid(id); if network.inputs.contains(id) {
continue;
}
nodes.push(Node { let fqn = &node.identifier.name;
id, let id = nid(id);
fqn: fqn.to_string().split("<").next().unwrap().to_owned(),
args: node.construction_args.new_function_args(), if nodes.iter().any(|x: &Node| x.id == id) {
}); continue;
}
nodes.push(Node {
id,
index: 0,
fqn: fqn.to_string().split("<").next().unwrap().to_owned(),
args: node.construction_args.new_function_args(),
});
}
let output = nid(&network.output);
output_nodes.push(output);
} }
dbg!(&input_nodes);
let template = include_str!("templates/spirv-template.rs"); let template = include_str!("templates/spirv-template.rs");
let mut tera = tera::Tera::default(); let mut tera = tera::Tera::default();
@ -156,8 +187,8 @@ pub fn serialize_gpu(network: &ProtoNetwork, io: &ShaderIO) -> anyhow::Result<St
let mut context = Context::new(); let mut context = Context::new();
context.insert("inputs", &inputs); context.insert("inputs", &inputs);
context.insert("input_nodes", &input_nodes); context.insert("input_nodes", &input_nodes);
context.insert("output_nodes", &output_nodes);
context.insert("nodes", &nodes); context.insert("nodes", &nodes);
context.insert("last_node", &nid(&network.output));
context.insert("compute_threads", &64); context.insert("compute_threads", &64);
Ok(tera.render("spirv", &context)?) Ok(tera.render("spirv", &context)?)
} }
@ -171,9 +202,9 @@ pub fn compile(dir: &Path) -> Result<spirv_builder::CompileResult, spirv_builder
.preserve_bindings(true) .preserve_bindings(true)
.release(true) .release(true)
.spirv_metadata(SpirvMetadata::Full) .spirv_metadata(SpirvMetadata::Full)
.extra_arg("no-early-report-zombies") //.extra_arg("no-early-report-zombies")
.extra_arg("no-infer-storage-classes") //.extra_arg("no-infer-storage-classes")
.extra_arg("spirt-passes=qptr") //.extra_arg("spirt-passes=qptr")
.build()?; .build()?;
Ok(result) Ok(result)

View file

@ -13,7 +13,7 @@ fn main() -> anyhow::Result<()> {
let metadata = compiler::Metadata::new("project".to_owned(), vec!["test@example.com".to_owned()]); let metadata = compiler::Metadata::new("project".to_owned(), vec!["test@example.com".to_owned()]);
compiler::create_files(&metadata, &request.network, &compile_dir, &request.io)?; compiler::create_files(&metadata, &request.networks, &compile_dir, &request.io)?;
let result = compiler::compile(&compile_dir)?; let result = compiler::compile(&compile_dir)?;
let bytes = std::fs::read(result.module.unwrap_single())?; let bytes = std::fs::read(result.module.unwrap_single())?;

View file

@ -4,32 +4,38 @@
#[cfg(target_arch = "spirv")] #[cfg(target_arch = "spirv")]
extern crate spirv_std; extern crate spirv_std;
#[cfg(target_arch = "spirv")] //#[cfg(target_arch = "spirv")]
pub mod gpu { //pub mod gpu {
use super::*; //use super::*;
use spirv_std::spirv; use spirv_std::spirv;
use spirv_std::glam::UVec3; use spirv_std::glam::UVec3;
#[allow(unused)] #[allow(unused)]
#[spirv(compute(threads({{compute_threads}})))] #[spirv(compute(threads({{compute_threads}})))]
pub fn eval ( pub fn eval (
#[spirv(global_invocation_id)] _global_index: UVec3,
{% for input in inputs %} {% for input in inputs %}
{{input}} {{input}},
{% endfor %} {% endfor %}
) { ) {
use graphene_core::Node; use graphene_core::Node;
/*
{% for input in input_nodes %} {% for input in input_nodes %}
let i{{loop.index0}} = graphene_core::value::CopiedNode::new(i{{loop.index0}}); let i{{input.index}} = graphene_core::value::CopiedNode::new(i{{input.index}});
let _{{input.id}} = {{input.fqn}}::new({% for arg in input.args %}{{arg}}, {% endfor %}); let _{{input.id}} = {{input.fqn}}::new({% for arg in input.args %}{{arg}}, {% endfor %});
let {{input.id}} = graphene_core::structural::ComposeNode::new(i{{loop.index0}}, _{{input.id}}); let {{input.id}} = graphene_core::structural::ComposeNode::new(i{{input.index}}, _{{input.id}});
{% endfor %} {% endfor %}
*/
{% for node in nodes %} {% for node in nodes %}
let {{node.id}} = {{node.fqn}}::new({% for arg in node.args %}{{arg}}, {% endfor %}); let {{node.id}} = {{node.fqn}}::new({% for arg in node.args %}{{arg}}, {% endfor %});
{% endfor %} {% endfor %}
let output = {{last_node}}.eval(());
// TODO: Write output to buffer
{% for output in output_nodes %}
let v = {{output}}.eval(());
o{{loop.index0}}[_global_index.x as usize] = v;
{% endfor %}
// TODO: Write output to buffer
} }
} //}

View file

@ -1,13 +1,15 @@
use bytemuck::{Pod, Zeroable};
use graph_craft::proto::ProtoNetwork; use graph_craft::proto::ProtoNetwork;
use graphene_core::*; use graphene_core::*;
use anyhow::Result; use anyhow::Result;
use dyn_any::StaticType; use dyn_any::{StaticType, StaticTypeSized};
use futures::Future; use futures::Future;
use glam::UVec3; use glam::UVec3;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::borrow::Cow; use std::borrow::Cow;
use std::pin::Pin; use std::pin::Pin;
use std::sync::Arc;
type ReadBackFuture = Pin<Box<dyn Future<Output = Result<Vec<u8>>>>>; type ReadBackFuture = Pin<Box<dyn Future<Output = Result<Vec<u8>>>>>;
@ -20,18 +22,18 @@ pub trait GpuExecutor {
fn create_uniform_buffer<T: ToUniformBuffer>(&self, data: T) -> Result<ShaderInput<Self::BufferHandle>>; fn create_uniform_buffer<T: ToUniformBuffer>(&self, data: T) -> Result<ShaderInput<Self::BufferHandle>>;
fn create_storage_buffer<T: ToStorageBuffer>(&self, data: T, options: StorageBufferOptions) -> Result<ShaderInput<Self::BufferHandle>>; fn create_storage_buffer<T: ToStorageBuffer>(&self, data: T, options: StorageBufferOptions) -> Result<ShaderInput<Self::BufferHandle>>;
fn create_output_buffer(&self, len: usize, ty: Type, cpu_readable: bool) -> Result<ShaderInput<Self::BufferHandle>>; fn create_output_buffer(&self, len: usize, ty: Type, cpu_readable: bool) -> Result<ShaderInput<Self::BufferHandle>>;
fn create_compute_pass(&self, layout: &PipelineLayout<Self>, read_back: Option<ShaderInput<Self::BufferHandle>>, instances: u32) -> Result<Self::CommandBuffer>; fn create_compute_pass(&self, layout: &PipelineLayout<Self>, read_back: Option<Arc<ShaderInput<Self::BufferHandle>>>, instances: u32) -> Result<Self::CommandBuffer>;
fn execute_compute_pipeline(&self, encoder: Self::CommandBuffer) -> Result<()>; fn execute_compute_pipeline(&self, encoder: Self::CommandBuffer) -> Result<()>;
fn read_output_buffer(&self, buffer: ShaderInput<Self::BufferHandle>) -> Result<ReadBackFuture>; fn read_output_buffer(&self, buffer: Arc<ShaderInput<Self::BufferHandle>>) -> ReadBackFuture;
} }
pub trait SpirVCompiler { pub trait SpirVCompiler {
fn compile(&self, network: ProtoNetwork, io: &ShaderIO) -> Result<Shader>; fn compile(&self, network: &[ProtoNetwork], io: &ShaderIO) -> Result<Shader>;
} }
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct CompileRequest { pub struct CompileRequest {
pub network: ProtoNetwork, pub networks: Vec<ProtoNetwork>,
pub io: ShaderIO, pub io: ShaderIO,
} }
@ -101,6 +103,10 @@ impl<BufferHandle> ShaderInput<BufferHandle> {
ShaderInput::ReadBackBuffer(_, ty) => ty.clone(), ShaderInput::ReadBackBuffer(_, ty) => ty.clone(),
} }
} }
pub fn is_output(&self) -> bool {
matches!(self, ShaderInput::OutputBuffer(_, _))
}
} }
pub struct Shader<'a> { pub struct Shader<'a> {
@ -119,6 +125,7 @@ pub struct StorageBufferOptions {
pub cpu_writable: bool, pub cpu_writable: bool,
pub gpu_writable: bool, pub gpu_writable: bool,
pub cpu_readable: bool, pub cpu_readable: bool,
pub storage: bool,
} }
pub trait ToUniformBuffer: StaticType { pub trait ToUniformBuffer: StaticType {
@ -127,13 +134,22 @@ pub trait ToUniformBuffer: StaticType {
} }
pub trait ToStorageBuffer: StaticType { pub trait ToStorageBuffer: StaticType {
type StorageBufferHandle;
fn to_bytes(&self) -> Cow<[u8]>; fn to_bytes(&self) -> Cow<[u8]>;
fn ty(&self) -> Type;
}
impl<T: Pod + Zeroable + StaticTypeSized> ToStorageBuffer for Vec<T> {
fn to_bytes(&self) -> Cow<[u8]> {
Cow::Borrowed(bytemuck::cast_slice(self.as_slice()))
}
fn ty(&self) -> Type {
concrete!(T)
}
} }
/// Collection of all arguments that are passed to the shader. /// Collection of all arguments that are passed to the shader.
pub struct Bindgroup<E: GpuExecutor + ?Sized> { pub struct Bindgroup<E: GpuExecutor + ?Sized> {
pub buffers: Vec<ShaderInput<E::BufferHandle>>, pub buffers: Vec<Arc<ShaderInput<E::BufferHandle>>>,
} }
/// A struct representing a compute pipeline. /// A struct representing a compute pipeline.
@ -141,7 +157,7 @@ pub struct PipelineLayout<E: GpuExecutor + ?Sized> {
pub shader: E::ShaderHandle, pub shader: E::ShaderHandle,
pub entry_point: String, pub entry_point: String,
pub bind_group: Bindgroup<E>, pub bind_group: Bindgroup<E>,
pub output_buffer: ShaderInput<E::BufferHandle>, pub output_buffer: Arc<ShaderInput<E::BufferHandle>>,
} }
/// Extracts arguments from the function arguments and wraps them in a node. /// Extracts arguments from the function arguments and wraps them in a node.
@ -185,6 +201,7 @@ fn storage_node<T: ToStorageBuffer, E: GpuExecutor>(data: T, executor: &'input E
cpu_writable: false, cpu_writable: false,
gpu_writable: true, gpu_writable: true,
cpu_readable: false, cpu_readable: false,
storage: true,
}, },
) )
.unwrap() .unwrap()
@ -216,8 +233,8 @@ pub struct CreateComputePassNode<Executor, Output, Instances> {
} }
#[node_macro::node_fn(CreateComputePassNode)] #[node_macro::node_fn(CreateComputePassNode)]
fn create_compute_pass_node<E: GpuExecutor>(layout: PipelineLayout<E>, executor: &'input E, output: ShaderInput<E::BufferHandle>, instances: u32) -> E::CommandBuffer { fn create_compute_pass_node<E: GpuExecutor + 'input>(layout: PipelineLayout<E>, executor: &'input E, output: ShaderInput<E::BufferHandle>, instances: u32) -> E::CommandBuffer {
executor.create_compute_pass(&layout, Some(output), instances).unwrap() executor.create_compute_pass(&layout, Some(output.into()), instances).unwrap()
} }
pub struct CreatePipelineLayoutNode<_E, EntryPoint, Bindgroup, OutputBuffer> { pub struct CreatePipelineLayoutNode<_E, EntryPoint, Bindgroup, OutputBuffer> {
@ -228,7 +245,7 @@ pub struct CreatePipelineLayoutNode<_E, EntryPoint, Bindgroup, OutputBuffer> {
} }
#[node_macro::node_fn(CreatePipelineLayoutNode<_E>)] #[node_macro::node_fn(CreatePipelineLayoutNode<_E>)]
fn create_pipeline_layout_node<_E: GpuExecutor>(shader: _E::ShaderHandle, entry_point: String, bind_group: Bindgroup<_E>, output_buffer: ShaderInput<_E::BufferHandle>) -> PipelineLayout<_E> { fn create_pipeline_layout_node<_E: GpuExecutor>(shader: _E::ShaderHandle, entry_point: String, bind_group: Bindgroup<_E>, output_buffer: Arc<ShaderInput<_E::BufferHandle>>) -> PipelineLayout<_E> {
PipelineLayout { PipelineLayout {
shader, shader,
entry_point, entry_point,

View file

@ -72,6 +72,7 @@ impl DocumentNode {
} }
NodeInput::Network(ty) => (ProtoNodeInput::Network(ty), ConstructionArgs::Nodes(vec![])), NodeInput::Network(ty) => (ProtoNodeInput::Network(ty), ConstructionArgs::Nodes(vec![])),
NodeInput::ShortCircut(ty) => (ProtoNodeInput::ShortCircut(ty), ConstructionArgs::Nodes(vec![])), NodeInput::ShortCircut(ty) => (ProtoNodeInput::ShortCircut(ty), ConstructionArgs::Nodes(vec![])),
NodeInput::Inline(inline) => (ProtoNodeInput::None, ConstructionArgs::Inline(inline)),
}; };
assert!(!self.inputs.iter().any(|input| matches!(input, NodeInput::Network(_))), "recieved non resolved parameter"); assert!(!self.inputs.iter().any(|input| matches!(input, NodeInput::Network(_))), "recieved non resolved parameter");
assert!(!self.inputs.iter().any(|input| matches!(input, NodeInput::ShortCircut(_))), "recieved non resolved parameter"); assert!(!self.inputs.iter().any(|input| matches!(input, NodeInput::ShortCircut(_))), "recieved non resolved parameter");
@ -82,6 +83,10 @@ impl DocumentNode {
&args &args
); );
// If we have one parameter of the type inline, set it as the construction args
if let &[NodeInput::Inline(ref inline)] = &self.inputs[..] {
args = ConstructionArgs::Inline(inline.clone());
}
if let ConstructionArgs::Nodes(nodes) = &mut args { if let ConstructionArgs::Nodes(nodes) = &mut args {
nodes.extend(self.inputs.iter().map(|input| match input { nodes.extend(self.inputs.iter().map(|input| match input {
NodeInput::Node { node_id, lambda, .. } => (*node_id, *lambda), NodeInput::Node { node_id, lambda, .. } => (*node_id, *lambda),
@ -176,6 +181,20 @@ pub enum NodeInput {
/// but actually consuming the provided input instead of passing it to its predecessor. /// but actually consuming the provided input instead of passing it to its predecessor.
/// See [NodeInput] docs for more explanation. /// See [NodeInput] docs for more explanation.
ShortCircut(Type), ShortCircut(Type),
Inline(InlineRust),
}
#[derive(Debug, Clone, PartialEq, Hash, DynAny)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct InlineRust {
pub expr: String,
pub ty: Type,
}
impl InlineRust {
pub fn new(expr: String, ty: Type) -> Self {
Self { expr, ty }
}
} }
impl NodeInput { impl NodeInput {
@ -203,6 +222,7 @@ impl NodeInput {
NodeInput::Value { exposed, .. } => *exposed, NodeInput::Value { exposed, .. } => *exposed,
NodeInput::Network(_) => false, NodeInput::Network(_) => false,
NodeInput::ShortCircut(_) => false, NodeInput::ShortCircut(_) => false,
NodeInput::Inline(_) => false,
} }
} }
pub fn ty(&self) -> Type { pub fn ty(&self) -> Type {
@ -211,6 +231,7 @@ impl NodeInput {
NodeInput::Value { tagged_value, .. } => tagged_value.ty(), NodeInput::Value { tagged_value, .. } => tagged_value.ty(),
NodeInput::Network(ty) => ty.clone(), NodeInput::Network(ty) => ty.clone(),
NodeInput::ShortCircut(ty) => ty.clone(), NodeInput::ShortCircut(ty) => ty.clone(),
NodeInput::Inline(_) => panic!("ty() called on NodeInput::Inline"),
} }
} }
} }
@ -225,7 +246,7 @@ pub enum DocumentNodeImplementation {
impl Default for DocumentNodeImplementation { impl Default for DocumentNodeImplementation {
fn default() -> Self { fn default() -> Self {
Self::Unresolved(NodeIdentifier::new("graphene_cored::ops::IdNode")) Self::Unresolved(NodeIdentifier::new("graphene_core::ops::IdNode"))
} }
} }
@ -299,10 +320,9 @@ impl NodeNetwork {
self.inputs.iter().map(move |id| self.nodes[id].inputs.get(0).map(|i| i.ty()).unwrap_or(concrete!(()))) self.inputs.iter().map(move |id| self.nodes[id].inputs.get(0).map(|i| i.ty()).unwrap_or(concrete!(())))
} }
/// An empty graph
pub fn value_network(node: DocumentNode) -> Self { pub fn value_network(node: DocumentNode) -> Self {
Self { Self {
inputs: vec![0], inputs: node.inputs.iter().filter(|input| matches!(input, NodeInput::Network(_))).map(|_| 0).collect(),
outputs: vec![NodeOutput::new(0, 0)], outputs: vec![NodeOutput::new(0, 0)],
nodes: [(0, node)].into_iter().collect(), nodes: [(0, node)].into_iter().collect(),
disabled: vec![], disabled: vec![],
@ -754,6 +774,7 @@ impl NodeNetwork {
} }
NodeInput::ShortCircut(_) => (), NodeInput::ShortCircut(_) => (),
NodeInput::Value { .. } => unreachable!("Value inputs should have been replaced with value nodes"), NodeInput::Value { .. } => unreachable!("Value inputs should have been replaced with value nodes"),
NodeInput::Inline(_) => (),
} }
} }
node.implementation = DocumentNodeImplementation::Unresolved("graphene_core::ops::IdNode".into()); node.implementation = DocumentNodeImplementation::Unresolved("graphene_core::ops::IdNode".into());
@ -772,14 +793,69 @@ impl NodeNetwork {
} }
} }
DocumentNodeImplementation::Unresolved(_) => (), DocumentNodeImplementation::Unresolved(_) => (),
DocumentNodeImplementation::Extract => { DocumentNodeImplementation::Extract => (),
panic!("Extract nodes should have been removed before flattening");
}
} }
assert!(!self.nodes.contains_key(&id), "Trying to insert a node into the network caused an id conflict"); assert!(!self.nodes.contains_key(&id), "Trying to insert a node into the network caused an id conflict");
self.nodes.insert(id, node); self.nodes.insert(id, node);
} }
fn remove_id_node(&mut self, id: NodeId) -> Result<(), String> {
let node = self.nodes.get(&id).ok_or_else(|| format!("Node with id {} does not exist", id))?.clone();
if let DocumentNodeImplementation::Unresolved(ident) = &node.implementation {
if ident.name == "graphene_core::ops::IdNode" {
assert_eq!(node.inputs.len(), 1, "Id node has more than one input");
if let NodeInput::Node { node_id, output_index, .. } = node.inputs[0] {
let input_node_id = node_id;
for output in self.nodes.values_mut() {
for input in &mut output.inputs {
if let NodeInput::Node {
node_id: output_node_id,
output_index: output_output_index,
..
} = input
{
if *output_node_id == id {
*output_node_id = input_node_id;
*output_output_index = output_index;
}
}
}
for NodeOutput {
ref mut node_id,
ref mut node_output_index,
} in self.outputs.iter_mut()
{
if *node_id == id {
*node_id = input_node_id;
*node_output_index = output_index;
}
}
}
}
self.nodes.remove(&id);
}
}
Ok(())
}
pub fn remove_redundant_id_nodes(&mut self) {
let id_nodes = self
.nodes
.iter()
.filter(|(_, node)| {
matches!(&node.implementation, DocumentNodeImplementation::Unresolved(ident) if ident == &NodeIdentifier::new("graphene_core::ops::IdNode"))
&& node.inputs.len() == 1
&& matches!(node.inputs[0], NodeInput::Node { .. })
})
.map(|(id, _)| *id)
.collect::<Vec<_>>();
for id in id_nodes {
if let Err(e) = self.remove_id_node(id) {
log::warn!("{}", e)
}
}
}
pub fn resolve_extract_nodes(&mut self) { pub fn resolve_extract_nodes(&mut self) {
let mut extraction_nodes = self let mut extraction_nodes = self
.nodes .nodes
@ -792,14 +868,20 @@ impl NodeNetwork {
for (_, node) in &mut extraction_nodes { for (_, node) in &mut extraction_nodes {
if let DocumentNodeImplementation::Extract = node.implementation { if let DocumentNodeImplementation::Extract = node.implementation {
assert_eq!(node.inputs.len(), 1); assert_eq!(node.inputs.len(), 1);
let NodeInput::Node { node_id, output_index, lambda } = node.inputs.pop().unwrap() else { let NodeInput::Node { node_id, output_index, .. } = node.inputs.pop().unwrap() else {
panic!("Extract node has no input"); panic!("Extract node has no input");
}; };
assert_eq!(output_index, 0); assert_eq!(output_index, 0);
assert!(lambda); // TODO: check if we can readd lambda checking
let input_node = self.nodes.get_mut(&node_id).unwrap(); let mut input_node = self.nodes.remove(&node_id).unwrap();
node.implementation = DocumentNodeImplementation::Unresolved("graphene_core::value::ValueNode".into()); node.implementation = DocumentNodeImplementation::Unresolved("graphene_core::value::ValueNode".into());
node.inputs = vec![NodeInput::value(TaggedValue::DocumentNode(input_node.clone()), false)]; for input in input_node.inputs.iter_mut() {
match input {
NodeInput::Node { .. } | NodeInput::Value { .. } => *input = NodeInput::Network(generic!(T)),
_ => (),
}
}
node.inputs = vec![NodeInput::value(TaggedValue::DocumentNode(input_node), false)];
} }
} }
self.nodes.extend(extraction_nodes); self.nodes.extend(extraction_nodes);
@ -926,6 +1008,7 @@ mod test {
implementation: DocumentNodeImplementation::Unresolved("graphene_core::ops::IdNode".into()), implementation: DocumentNodeImplementation::Unresolved("graphene_core::ops::IdNode".into()),
..Default::default() ..Default::default()
}; };
// TODO: Extend test cases to test nested network
let mut extraction_network = NodeNetwork { let mut extraction_network = NodeNetwork {
inputs: vec![], inputs: vec![],
outputs: vec![NodeOutput::new(1, 0)], outputs: vec![NodeOutput::new(1, 0)],
@ -945,7 +1028,7 @@ mod test {
..Default::default() ..Default::default()
}; };
extraction_network.resolve_extract_nodes(); extraction_network.resolve_extract_nodes();
assert_eq!(extraction_network.nodes.len(), 2); assert_eq!(extraction_network.nodes.len(), 1);
let inputs = extraction_network.nodes.get(&1).unwrap().inputs.clone(); let inputs = extraction_network.nodes.get(&1).unwrap().inputs.clone();
assert_eq!(inputs.len(), 1); assert_eq!(inputs.len(), 1);
assert!(matches!(&inputs[0], &NodeInput::Value{ tagged_value: TaggedValue::DocumentNode(ref network), ..} if network == &id_node)); assert!(matches!(&inputs[0], &NodeInput::Value{ tagged_value: TaggedValue::DocumentNode(ref network), ..} if network == &id_node));

View file

@ -189,7 +189,7 @@ impl<'a> TaggedValue {
pub fn to_primitive_string(&self) -> String { pub fn to_primitive_string(&self) -> String {
match self { match self {
TaggedValue::None => "()".to_string(), TaggedValue::None => "()".to_string(),
TaggedValue::String(x) => x.clone(), TaggedValue::String(x) => format!("\"{}\"", x),
TaggedValue::U32(x) => x.to_string(), TaggedValue::U32(x) => x.to_string(),
TaggedValue::F32(x) => x.to_string(), TaggedValue::F32(x) => x.to_string(),
TaggedValue::F64(x) => x.to_string(), TaggedValue::F64(x) => x.to_string(),

View file

@ -10,19 +10,23 @@ pub struct Compiler {}
impl Compiler { impl Compiler {
pub fn compile(&self, mut network: NodeNetwork, resolve_inputs: bool) -> impl Iterator<Item = ProtoNetwork> { pub fn compile(&self, mut network: NodeNetwork, resolve_inputs: bool) -> impl Iterator<Item = ProtoNetwork> {
let node_ids = network.nodes.keys().copied().collect::<Vec<_>>(); let node_ids = network.nodes.keys().copied().collect::<Vec<_>>();
network.resolve_extract_nodes();
println!("flattening"); println!("flattening");
for id in node_ids { for id in node_ids {
network.flatten(id); network.flatten(id);
} }
network.remove_redundant_id_nodes();
network.resolve_extract_nodes();
network.remove_dead_nodes();
let proto_networks = network.into_proto_networks(); let proto_networks = network.into_proto_networks();
proto_networks.map(move |mut proto_network| { proto_networks.map(move |mut proto_network| {
if resolve_inputs { if resolve_inputs {
println!("resolving inputs"); println!("resolving inputs");
log::debug!("resolving inputs");
proto_network.resolve_inputs(); proto_network.resolve_inputs();
} }
proto_network.reorder_ids(); proto_network.reorder_ids();
proto_network.generate_stable_node_ids(); proto_network.generate_stable_node_ids();
log::debug!("proto network: {:?}", proto_network);
proto_network proto_network
}) })
} }

View file

@ -4,8 +4,8 @@ use std::collections::{HashMap, HashSet};
use std::hash::Hash; use std::hash::Hash;
use xxhash_rust::xxh3::Xxh3; use xxhash_rust::xxh3::Xxh3;
use crate::document::value;
use crate::document::NodeId; use crate::document::NodeId;
use crate::document::{value, InlineRust};
use dyn_any::DynAny; use dyn_any::DynAny;
use graphene_core::*; use graphene_core::*;
#[cfg(feature = "serde")] #[cfg(feature = "serde")]
@ -66,6 +66,10 @@ impl core::fmt::Display for ProtoNetwork {
write_node(f, network, id.0, indent + 1)?; write_node(f, network, id.0, indent + 1)?;
} }
} }
ConstructionArgs::Inline(inline) => {
f.write_str(&"\t".repeat(indent + 1))?;
f.write_fmt(format_args!("Inline construction argument: {inline:?}"))?
}
} }
f.write_str(&"\t".repeat(indent))?; f.write_str(&"\t".repeat(indent))?;
f.write_str("}\n")?; f.write_str("}\n")?;
@ -83,6 +87,7 @@ pub enum ConstructionArgs {
Value(value::TaggedValue), Value(value::TaggedValue),
// the bool indicates whether to treat the node as lambda node // the bool indicates whether to treat the node as lambda node
Nodes(Vec<(NodeId, bool)>), Nodes(Vec<(NodeId, bool)>),
Inline(InlineRust),
} }
impl PartialEq for ConstructionArgs { impl PartialEq for ConstructionArgs {
@ -105,6 +110,7 @@ impl Hash for ConstructionArgs {
} }
} }
Self::Value(value) => value.hash(state), Self::Value(value) => value.hash(state),
Self::Inline(inline) => inline.hash(state),
} }
} }
} }
@ -114,6 +120,7 @@ impl ConstructionArgs {
match self { match self {
ConstructionArgs::Nodes(nodes) => nodes.iter().map(|n| format!("&n{}", n.0)).collect(), ConstructionArgs::Nodes(nodes) => nodes.iter().map(|n| format!("&n{}", n.0)).collect(),
ConstructionArgs::Value(value) => vec![value.to_primitive_string()], ConstructionArgs::Value(value) => vec![value.to_primitive_string()],
ConstructionArgs::Inline(inline) => vec![inline.expr.clone()],
} }
} }
} }
@ -453,6 +460,7 @@ impl TypingContext {
.map(|node| node.ty()) .map(|node| node.ty())
}) })
.collect::<Result<Vec<Type>, String>>()?, .collect::<Result<Vec<Type>, String>>()?,
ConstructionArgs::Inline(ref inline) => vec![inline.ty.clone()],
}; };
// Get the node input type from the proto node declaration // Get the node input type from the proto node declaration

View file

@ -10,8 +10,13 @@ license = "MIT OR Apache-2.0"
[features] [features]
memoization = ["once_cell"] memoization = ["once_cell"]
default = ["memoization"] default = ["memoization", "wgpu"]
gpu = ["graphene-core/gpu", "gpu-compiler-bin-wrapper", "compilation-client", "gpu-executor"] gpu = [
"graphene-core/gpu",
"gpu-compiler-bin-wrapper",
"compilation-client",
"gpu-executor",
]
vulkan = ["gpu", "vulkan-executor"] vulkan = ["gpu", "vulkan-executor"]
wgpu = ["gpu", "wgpu-executor"] wgpu = ["gpu", "wgpu-executor"]
quantization = ["autoquant"] quantization = ["autoquant"]

View file

@ -1,4 +1,7 @@
use glam::UVec3;
use gpu_executor::{Bindgroup, PipelineLayout, StorageBufferOptions};
use gpu_executor::{GpuExecutor, ShaderIO, ShaderInput}; use gpu_executor::{GpuExecutor, ShaderIO, ShaderInput};
use graph_craft::document::value::TaggedValue;
use graph_craft::document::*; use graph_craft::document::*;
use graph_craft::proto::*; use graph_craft::proto::*;
use graphene_core::raster::*; use graphene_core::raster::*;
@ -9,6 +12,7 @@ use wgpu_executor::NewExecutor;
use bytemuck::Pod; use bytemuck::Pod;
use core::marker::PhantomData; use core::marker::PhantomData;
use dyn_any::StaticTypeSized; use dyn_any::StaticTypeSized;
use std::sync::Arc;
pub struct GpuCompiler<TypingContext, ShaderIO> { pub struct GpuCompiler<TypingContext, ShaderIO> {
typing_context: TypingContext, typing_context: TypingContext,
@ -19,25 +23,177 @@ pub struct GpuCompiler<TypingContext, ShaderIO> {
#[node_macro::node_fn(GpuCompiler)] #[node_macro::node_fn(GpuCompiler)]
async fn compile_gpu(node: &'input DocumentNode, mut typing_context: TypingContext, io: ShaderIO) -> compilation_client::Shader { async fn compile_gpu(node: &'input DocumentNode, mut typing_context: TypingContext, io: ShaderIO) -> compilation_client::Shader {
let compiler = graph_craft::executor::Compiler {}; let compiler = graph_craft::executor::Compiler {};
let DocumentNodeImplementation::Network(network) = node.implementation; let DocumentNodeImplementation::Network(ref network) = node.implementation else { panic!() };
let proto_network = compiler.compile_single(network, true).unwrap(); let proto_networks: Vec<_> = compiler.compile(network.clone(), true).collect();
typing_context.update(&proto_network);
let input_types = proto_network.inputs.iter().map(|id| typing_context.get_type(*id).unwrap()).map(|node_io| node_io.output).collect();
let output_type = typing_context.get_type(proto_network.output).unwrap().output;
let bytes = compilation_client::compile(proto_network, input_types, output_type, io).await.unwrap(); for network in proto_networks.iter() {
bytes typing_context.update(network).expect("Failed to type check network");
}
// TODO: do a proper union
let input_types = proto_networks[0]
.inputs
.iter()
.map(|id| typing_context.type_of(*id).unwrap())
.map(|node_io| node_io.output.clone())
.collect();
let output_types = proto_networks.iter().map(|network| typing_context.type_of(network.output).unwrap().output.clone()).collect();
compilation_client::compile(proto_networks, input_types, output_types, io).await.unwrap()
} }
pub struct MapGpuNode<Shader> { pub struct MapGpuNode<Node> {
shader: Shader, node: Node,
} }
#[node_macro::node_fn(MapGpuNode)]
async fn map_gpu(image: ImageFrame<Color>, node: DocumentNode) -> ImageFrame<Color> {
log::debug!("Executing gpu node");
let compiler = graph_craft::executor::Compiler {};
let inner_network = NodeNetwork::value_network(node);
log::debug!("inner_network: {:?}", inner_network);
let network = NodeNetwork {
inputs: vec![], //vec![0, 1],
outputs: vec![NodeOutput::new(1, 0)],
nodes: [
DocumentNode {
name: "Slice".into(),
inputs: vec![NodeInput::Inline(InlineRust::new("i0[_global_index.x as usize]".into(), concrete![Color]))],
implementation: DocumentNodeImplementation::Unresolved("graphene_core::value::CopiedNode".into()),
..Default::default()
},
/*DocumentNode {
name: "Index".into(),
//inputs: vec![NodeInput::Network(concrete!(UVec3))],
inputs: vec![NodeInput::Inline(InlineRust::new("i1.x as usize".into(), concrete![u32]))],
implementation: DocumentNodeImplementation::Unresolved("graphene_core::value::CopiedNode".into()),
..Default::default()
},*/
/*
DocumentNode {
name: "GetNode".into(),
inputs: vec![NodeInput::node(1, 0), NodeInput::node(0, 0)],
implementation: DocumentNodeImplementation::Unresolved("graphene_core::storage::GetNode".into()),
..Default::default()
},*/
DocumentNode {
name: "MapNode".into(),
inputs: vec![NodeInput::node(0, 0)],
implementation: DocumentNodeImplementation::Network(inner_network),
..Default::default()
},
/*
DocumentNode {
name: "SaveNode".into(),
inputs: vec![
//NodeInput::node(0, 0),
NodeInput::Inline(InlineRust::new(
"o0[_global_index.x as usize] = i0[_global_index.x as usize]".into(),
Type::Fn(Box::new(concrete!(Color)), Box::new(concrete!(()))),
)),
],
implementation: DocumentNodeImplementation::Unresolved("graphene_core::value::ValueNode".into()),
..Default::default()
},
*/
]
.into_iter()
.enumerate()
.map(|(i, n)| (i as u64, n))
.collect(),
..Default::default()
};
log::debug!("compiling network");
let proto_networks = compiler.compile(network.clone(), true).collect();
log::debug!("compiling shader");
let shader = compilation_client::compile(
proto_networks,
vec![concrete!(Color)], //, concrete!(u32)],
vec![concrete!(Color)],
ShaderIO {
inputs: vec![
ShaderInput::StorageBuffer((), concrete!(Color)),
//ShaderInput::Constant(gpu_executor::GPUConstant::GlobalInvocationId),
ShaderInput::OutputBuffer((), concrete!(Color)),
],
output: ShaderInput::OutputBuffer((), concrete!(Color)),
},
)
.await
.unwrap();
//return ImageFrame::empty();
let len = image.image.data.len();
log::debug!("instances: {}", len);
let executor = NewExecutor::new().await.unwrap();
log::debug!("creating buffer");
let storage_buffer = executor
.create_storage_buffer(
image.image.data.clone(),
StorageBufferOptions {
cpu_writable: false,
gpu_writable: true,
cpu_readable: false,
storage: true,
},
)
.unwrap();
let storage_buffer = Arc::new(storage_buffer);
let output_buffer = executor.create_output_buffer(len, concrete!(Color), false).unwrap();
let output_buffer = Arc::new(output_buffer);
let readback_buffer = executor.create_output_buffer(len, concrete!(Color), true).unwrap();
let readback_buffer = Arc::new(readback_buffer);
log::debug!("created buffer");
let bind_group = Bindgroup {
buffers: vec![storage_buffer.clone()],
};
let shader = gpu_executor::Shader {
source: shader.spirv_binary.into(),
name: "gpu::eval",
io: shader.io,
};
log::debug!("loading shader");
log::debug!("shader: {:?}", shader.source);
let shader = executor.load_shader(shader).unwrap();
log::debug!("loaded shader");
let pipeline = PipelineLayout {
shader,
entry_point: "eval".to_string(),
bind_group,
output_buffer: output_buffer.clone(),
};
log::debug!("created pipeline");
let compute_pass = executor.create_compute_pass(&pipeline, Some(readback_buffer.clone()), len.min(65535) as u32).unwrap();
executor.execute_compute_pipeline(compute_pass).unwrap();
log::debug!("executed pipeline");
log::debug!("reading buffer");
let result = executor.read_output_buffer(readback_buffer).await.unwrap();
let colors = bytemuck::pod_collect_to_vec::<u8, Color>(result.as_slice());
ImageFrame {
image: Image {
data: colors,
width: image.image.width,
height: image.image.height,
},
transform: image.transform,
}
/*
let executor: GpuExecutor = GpuExecutor::new(Context::new().await.unwrap(), shader.into(), "gpu::eval".into()).unwrap();
let data: Vec<_> = input.into_iter().collect();
let result = executor.execute(Box::new(data)).unwrap();
let result = dyn_any::downcast::<Vec<_O>>(result).unwrap();
*result
*/
}
/*
#[node_macro::node_fn(MapGpuNode)] #[node_macro::node_fn(MapGpuNode)]
async fn map_gpu(inputs: Vec<ShaderInput<<NewExecutor as GpuExecutor>::BufferHandle>>, shader: &'any_input compilation_client::Shader) { async fn map_gpu(inputs: Vec<ShaderInput<<NewExecutor as GpuExecutor>::BufferHandle>>, shader: &'any_input compilation_client::Shader) {
use graph_craft::executor::Executor; use graph_craft::executor::Executor;
let executor = NewExecutor::new().unwrap(); let executor = NewExecutor::new().unwrap();
for input in shader.inputs.iter() { for input in shader.io.inputs.iter() {
let buffer = executor.create_storage_buffer(&self, data, options)
let buffer = executor.create_buffer(input.size).unwrap(); let buffer = executor.create_buffer(input.size).unwrap();
executor.write_buffer(buffer, input.data).unwrap(); executor.write_buffer(buffer, input.data).unwrap();
} }
@ -74,6 +230,7 @@ fn map_gpu_single_image(input: Image<Color>, node: String) -> Image<Color> {
inputs: vec![NodeInput::Network(concrete!(Color))], inputs: vec![NodeInput::Network(concrete!(Color))],
implementation: DocumentNodeImplementation::Unresolved(identifier), implementation: DocumentNodeImplementation::Unresolved(identifier),
metadata: DocumentNodeMetadata::default(), metadata: DocumentNodeMetadata::default(),
..Default::default()
}, },
)] )]
.into_iter() .into_iter()
@ -85,3 +242,4 @@ fn map_gpu_single_image(input: Image<Color>, node: String) -> Image<Color> {
let data = map_node.eval(input.data.clone()); let data = map_node.eval(input.data.clone());
Image { data, ..input } Image { data, ..input }
} }
*/

View file

@ -203,6 +203,7 @@ impl BorrowTree {
let node = unsafe { node.erase_lifetime() }; let node = unsafe { node.erase_lifetime() };
self.store_node(Arc::new(node.into()), id); self.store_node(Arc::new(node.into()), id);
} }
ConstructionArgs::Inline(_) => unimplemented!("Inline nodes are not supported yet"),
ConstructionArgs::Nodes(ids) => { ConstructionArgs::Nodes(ids) => {
let ids: Vec<_> = ids.iter().map(|(id, _)| *id).collect(); let ids: Vec<_> = ids.iter().map(|(id, _)| *id).collect();
let construction_nodes = self.node_refs(&ids); let construction_nodes = self.node_refs(&ids);

View file

@ -131,7 +131,7 @@ mod tests {
0, 0,
DocumentNode { DocumentNode {
name: "id".into(), name: "id".into(),
inputs: vec![NodeInput::Network(concrete!(u32))], inputs: vec![NodeInput::ShortCircut(concrete!(u32))],
implementation: DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_core::ops::IdNode")), implementation: DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_core::ops::IdNode")),
..Default::default() ..Default::default()
}, },

View file

@ -1,5 +1,6 @@
use glam::{DAffine2, DVec2}; use glam::{DAffine2, DVec2};
use graph_craft::document::DocumentNode;
use graphene_core::ops::IdNode; use graphene_core::ops::IdNode;
use graphene_core::vector::VectorData; use graphene_core::vector::VectorData;
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
@ -219,7 +220,7 @@ fn node_registry() -> HashMap<NodeIdentifier, HashMap<NodeIOTypes, NodeConstruct
|args| { |args| {
Box::pin(async move { Box::pin(async move {
let document_node: DowncastBothNode<(), DocumentNode> = DowncastBothNode::new(args[0]); let document_node: DowncastBothNode<(), DocumentNode> = DowncastBothNode::new(args[0]);
let document_node = ClonedNode::new(document_node.eval(()).await); //let document_node = ClonedNode::new(document_node.eval(()));
let node = graphene_std::executor::MapGpuNode::new(document_node); let node = graphene_std::executor::MapGpuNode::new(document_node);
let any: DynAnyNode<ImageFrame<Color>, _, _> = graphene_std::any::DynAnyNode::new(graphene_core::value::ValueNode::new(node)); let any: DynAnyNode<ImageFrame<Color>, _, _> = graphene_std::any::DynAnyNode::new(graphene_core::value::ValueNode::new(node));
Box::pin(any) as TypeErasedPinned Box::pin(any) as TypeErasedPinned

View file

@ -23,7 +23,7 @@ base64 = "0.13"
bytemuck = {version = "1.8" } bytemuck = {version = "1.8" }
anyhow = "1.0.66" anyhow = "1.0.66"
wgpu = { version = "0.14.2", features = ["spirv"] } wgpu = { version = "0.16", features = ["spirv"] }
spirv = "0.2.0" spirv = "0.2.0"
futures-intrusive = "0.5.0" futures-intrusive = "0.5.0"
futures = "0.3.25" futures = "0.3.25"

View file

@ -11,7 +11,7 @@ pub struct Context {
impl Context { impl Context {
pub async fn new() -> Option<Self> { pub async fn new() -> Option<Self> {
// Instantiates instance of WebGPU // Instantiates instance of WebGPU
let instance = wgpu::Instance::new(wgpu::Backends::all()); let instance = wgpu::Instance::new(wgpu::InstanceDescriptor::default());
// `request_adapter` instantiates the general connection to the GPU // `request_adapter` instantiates the general connection to the GPU
let adapter = instance.request_adapter(&wgpu::RequestAdapterOptions::default()).await?; let adapter = instance.request_adapter(&wgpu::RequestAdapterOptions::default()).await?;

View file

@ -9,6 +9,7 @@ use graph_craft::Type;
use anyhow::{bail, Result}; use anyhow::{bail, Result};
use futures::Future; use futures::Future;
use std::pin::Pin; use std::pin::Pin;
use std::sync::Arc;
use wgpu::util::DeviceExt; use wgpu::util::DeviceExt;
use wgpu::{Buffer, BufferDescriptor, CommandBuffer, ShaderModule}; use wgpu::{Buffer, BufferDescriptor, CommandBuffer, ShaderModule};
@ -42,8 +43,11 @@ impl gpu_executor::GpuExecutor for NewExecutor {
fn create_storage_buffer<T: ToStorageBuffer>(&self, data: T, options: StorageBufferOptions) -> Result<ShaderInput<Self::BufferHandle>> { fn create_storage_buffer<T: ToStorageBuffer>(&self, data: T, options: StorageBufferOptions) -> Result<ShaderInput<Self::BufferHandle>> {
let bytes = data.to_bytes(); let bytes = data.to_bytes();
let mut usage = wgpu::BufferUsages::STORAGE; let mut usage = wgpu::BufferUsages::empty();
if options.storage {
usage |= wgpu::BufferUsages::STORAGE;
}
if options.gpu_writable { if options.gpu_writable {
usage |= wgpu::BufferUsages::COPY_SRC | wgpu::BufferUsages::COPY_DST; usage |= wgpu::BufferUsages::COPY_SRC | wgpu::BufferUsages::COPY_DST;
} }
@ -54,15 +58,17 @@ impl gpu_executor::GpuExecutor for NewExecutor {
usage |= wgpu::BufferUsages::MAP_WRITE | wgpu::BufferUsages::COPY_SRC; usage |= wgpu::BufferUsages::MAP_WRITE | wgpu::BufferUsages::COPY_SRC;
} }
log::debug!("Creating storage buffer with usage {:?} and len: {}", usage, bytes.len());
let buffer = self.context.device.create_buffer_init(&wgpu::util::BufferInitDescriptor { let buffer = self.context.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: None, label: None,
contents: bytes.as_ref(), contents: bytes.as_ref(),
usage, usage,
}); });
Ok(ShaderInput::StorageBuffer(buffer, Type::new::<T>())) Ok(ShaderInput::StorageBuffer(buffer, data.ty()))
} }
fn create_output_buffer(&self, len: usize, ty: Type, cpu_readable: bool) -> Result<ShaderInput<Self::BufferHandle>> { fn create_output_buffer(&self, len: usize, ty: Type, cpu_readable: bool) -> Result<ShaderInput<Self::BufferHandle>> {
log::debug!("Creating output buffer with len: {}", len);
let create_buffer = |usage| { let create_buffer = |usage| {
Ok::<_, anyhow::Error>(self.context.device.create_buffer(&BufferDescriptor { Ok::<_, anyhow::Error>(self.context.device.create_buffer(&BufferDescriptor {
label: None, label: None,
@ -72,13 +78,12 @@ impl gpu_executor::GpuExecutor for NewExecutor {
})) }))
}; };
let buffer = match cpu_readable { let buffer = match cpu_readable {
true => ShaderInput::ReadBackBuffer(create_buffer(wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ)?, ty), true => ShaderInput::ReadBackBuffer(create_buffer(wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ)?, ty),
false => ShaderInput::OutputBuffer(create_buffer(wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC)?, ty), false => ShaderInput::OutputBuffer(create_buffer(wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC)?, ty),
}; };
Ok(buffer) Ok(buffer)
} }
fn create_compute_pass(&self, layout: &gpu_executor::PipelineLayout<Self>, read_back: Option<Arc<ShaderInput<Self::BufferHandle>>>, instances: u32) -> Result<CommandBuffer> {
fn create_compute_pass(&self, layout: &gpu_executor::PipelineLayout<Self>, read_back: Option<ShaderInput<Self::BufferHandle>>, instances: u32) -> Result<CommandBuffer> {
let compute_pipeline = self.context.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor { let compute_pipeline = self.context.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: None, label: None,
layout: None, layout: None,
@ -116,10 +121,13 @@ impl gpu_executor::GpuExecutor for NewExecutor {
} }
// Sets adds copy operation to command encoder. // Sets adds copy operation to command encoder.
// Will copy data from storage buffer on GPU to staging buffer on CPU. // Will copy data from storage buffer on GPU to staging buffer on CPU.
if let Some(ShaderInput::ReadBackBuffer(output, ty)) = read_back { if let Some(buffer) = read_back {
let ShaderInput::ReadBackBuffer(output, ty) = buffer.as_ref() else {
bail!("Tried to read back from a non read back buffer");
};
let size = output.size(); let size = output.size();
assert_eq!(size, layout.output_buffer.buffer().unwrap().size()); assert_eq!(size, layout.output_buffer.buffer().unwrap().size());
assert_eq!(ty, layout.output_buffer.ty()); assert_eq!(ty, &layout.output_buffer.ty());
encoder.copy_buffer_to_buffer( encoder.copy_buffer_to_buffer(
layout.output_buffer.buffer().ok_or_else(|| anyhow::anyhow!("Tried to use an non buffer as the shader output"))?, layout.output_buffer.buffer().ok_or_else(|| anyhow::anyhow!("Tried to use an non buffer as the shader output"))?,
0, 0,
@ -143,9 +151,9 @@ impl gpu_executor::GpuExecutor for NewExecutor {
Ok(()) Ok(())
} }
fn read_output_buffer(&self, buffer: ShaderInput<Self::BufferHandle>) -> Result<Pin<Box<dyn Future<Output = Result<Vec<u8>>>>>> { fn read_output_buffer(&self, buffer: Arc<ShaderInput<Self::BufferHandle>>) -> Pin<Box<dyn Future<Output = Result<Vec<u8>>>>> {
if let ShaderInput::ReadBackBuffer(buffer, _) = buffer { let future = Box::pin(async move {
let future = Box::pin(async move { if let ShaderInput::ReadBackBuffer(buffer, _) = buffer.as_ref() {
let buffer_slice = buffer.slice(..); let buffer_slice = buffer.slice(..);
// Sets the buffer up for mapping, sending over the result of the mapping back to us when it is finished. // Sets the buffer up for mapping, sending over the result of the mapping back to us when it is finished.
@ -175,17 +183,17 @@ impl gpu_executor::GpuExecutor for NewExecutor {
} else { } else {
bail!("failed to run compute on gpu!") bail!("failed to run compute on gpu!")
} }
}); } else {
Ok(future) bail!("Tried to read a non readback buffer")
} else { }
bail!("Tried to read a non readback buffer") });
} future
} }
} }
impl NewExecutor { impl NewExecutor {
pub fn new() -> Option<Self> { pub async fn new() -> Option<Self> {
let context = Context::new_sync()?; let context = Context::new().await?;
Some(Self { context }) Some(Self { context })
} }
} }