Desktop: Add rudimentary support for custom WGPU adapter selection (#3201)

* rudimentary custom wgpu adapter selection

* WgpuContextBuilder

* wasm fix

* fix wasm warnings

* Clean up

* Review suggestions

* fix
This commit is contained in:
Timon 2025-09-25 14:38:26 +00:00 committed by GitHub
parent 4e47b5db93
commit ed22e6a63d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 181 additions and 54 deletions

View file

@ -1,49 +1,64 @@
use std::sync::Arc;
use wgpu::{Device, Instance, Queue};
use wgpu::{Adapter, Backends, Device, Features, Instance, Queue};
#[derive(Debug, Clone)]
pub struct Context {
pub device: Arc<Device>,
pub queue: Arc<Queue>,
pub instance: Arc<Instance>,
pub adapter: Arc<wgpu::Adapter>,
pub adapter: Arc<Adapter>,
}
impl Context {
pub async fn new() -> Option<Self> {
// Instantiates instance of WebGPU
let instance_descriptor = wgpu::InstanceDescriptor {
backends: wgpu::Backends::all(),
..Default::default()
};
let instance = Instance::new(&instance_descriptor);
ContextBuilder::new().build().await
}
}
let adapter_options = wgpu::RequestAdapterOptions {
power_preference: wgpu::PowerPreference::HighPerformance,
compatible_surface: None,
force_fallback_adapter: false,
};
// `request_adapter` instantiates the general connection to the GPU
let adapter = instance.request_adapter(&adapter_options).await.ok()?;
let required_limits = adapter.limits();
// `request_device` instantiates the feature specific connection to the GPU, defining some parameters,
// `features` being the available features.
let (device, queue) = adapter
.request_device(&wgpu::DeviceDescriptor {
label: None,
#[cfg(target_family = "wasm")]
required_features: wgpu::Features::empty(),
#[cfg(not(target_family = "wasm"))]
required_features: wgpu::Features::PUSH_CONSTANTS,
required_limits,
memory_hints: Default::default(),
trace: wgpu::Trace::Off,
})
.await
.ok()?;
Some(Self {
#[derive(Default)]
pub struct ContextBuilder {
backends: Backends,
features: Features,
}
impl ContextBuilder {
pub fn new() -> Self {
Self {
backends: Backends::all(),
features: Features::empty(),
}
}
pub fn with_backends(mut self, backends: Backends) -> Self {
self.backends = backends;
self
}
pub fn with_features(mut self, features: Features) -> Self {
self.features = features;
self
}
}
#[cfg(not(target_family = "wasm"))]
impl ContextBuilder {
pub async fn build(self) -> Option<Context> {
self.build_with_adapter_selection_inner(None::<fn(&[Adapter]) -> Option<usize>>).await
}
pub async fn build_with_adapter_selection<S>(self, select: S) -> Option<Context>
where
S: Fn(&[Adapter]) -> Option<usize>,
{
self.build_with_adapter_selection_inner(Some(select)).await
}
pub async fn available_adapters_fmt(&self) -> impl std::fmt::Display {
let instance = self.build_instance();
fmt::AvailableAdaptersFormatter(instance.enumerate_adapters(self.backends))
}
}
#[cfg(target_family = "wasm")]
impl ContextBuilder {
pub async fn build(self) -> Option<Context> {
let instance = self.build_instance();
let adapter = self.request_adapter(&instance).await?;
let (device, queue) = self.request_device(&adapter).await?;
Some(Context {
device: Arc::new(device),
queue: Arc::new(queue),
adapter: Arc::new(adapter),
@ -51,3 +66,86 @@ impl Context {
})
}
}
impl ContextBuilder {
fn build_instance(&self) -> Instance {
Instance::new(&wgpu::InstanceDescriptor {
backends: self.backends,
..Default::default()
})
}
async fn request_adapter(&self, instance: &Instance) -> Option<Adapter> {
let request_adapter_options = wgpu::RequestAdapterOptions {
power_preference: wgpu::PowerPreference::HighPerformance,
compatible_surface: None,
force_fallback_adapter: false,
};
instance.request_adapter(&request_adapter_options).await.ok()
}
async fn request_device(&self, adapter: &Adapter) -> Option<(Device, Queue)> {
let device_descriptor = wgpu::DeviceDescriptor {
label: None,
required_features: self.features,
required_limits: adapter.limits(),
memory_hints: Default::default(),
trace: wgpu::Trace::Off,
};
adapter.request_device(&device_descriptor).await.ok()
}
}
#[cfg(not(target_family = "wasm"))]
impl ContextBuilder {
async fn build_with_adapter_selection_inner<S>(self, select: Option<S>) -> Option<Context>
where
S: Fn(&[Adapter]) -> Option<usize>,
{
let instance = self.build_instance();
let selected_adapter = if let Some(select) = select {
self.select_adapter(&instance, select)
} else if cfg!(target_os = "windows") {
self.select_adapter(&instance, |adapters: &[Adapter]| adapters.iter().position(|a| a.get_info().backend == wgpu::Backend::Dx12))
} else {
None
};
let adapter = if let Some(adapter) = selected_adapter { adapter } else { self.request_adapter(&instance).await? };
let (device, queue) = self.request_device(&adapter).await?;
Some(Context {
device: Arc::new(device),
queue: Arc::new(queue),
adapter: Arc::new(adapter),
instance: Arc::new(instance),
})
}
fn select_adapter<S>(&self, instance: &Instance, select: S) -> Option<Adapter>
where
S: Fn(&[Adapter]) -> Option<usize>,
{
let mut adapters = instance.enumerate_adapters(self.backends);
let selected_index = select(&adapters)?;
if selected_index >= adapters.len() {
return None;
}
Some(adapters.remove(selected_index))
}
}
#[cfg(not(target_family = "wasm"))]
mod fmt {
use super::*;
pub(super) struct AvailableAdaptersFormatter(pub(super) Vec<Adapter>);
impl std::fmt::Display for AvailableAdaptersFormatter {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
for (i, adapter) in self.0.iter().enumerate() {
let info = adapter.get_info();
writeln!(
f,
"[{}] {:?} {:?} (Name: {}, Driver: {}, Device: {})",
i, info.backend, info.device_type, info.name, info.driver, info.device,
)?;
}
Ok(())
}
}
}

View file

@ -4,7 +4,6 @@ pub mod texture_upload;
use crate::shader_runtime::ShaderRuntime;
use anyhow::Result;
pub use context::Context;
use dyn_any::StaticType;
use futures::lock::Mutex;
use glam::UVec2;
@ -16,9 +15,14 @@ use vello::{AaConfig, AaSupport, RenderParams, Renderer, RendererOptions, Scene}
use wgpu::util::TextureBlitter;
use wgpu::{Origin3d, SurfaceConfiguration, TextureAspect};
pub use context::Context as WgpuContext;
pub use context::ContextBuilder as WgpuContextBuilder;
pub use wgpu::Backends as WgpuBackends;
pub use wgpu::Features as WgpuFeatures;
#[derive(dyn_any::DynAny)]
pub struct WgpuExecutor {
pub context: Context,
pub context: WgpuContext,
vello_renderer: Mutex<Renderer>,
pub shader_runtime: ShaderRuntime,
}
@ -182,10 +186,10 @@ impl WgpuExecutor {
impl WgpuExecutor {
pub async fn new() -> Option<Self> {
Self::with_context(Context::new().await?)
Self::with_context(WgpuContext::new().await?)
}
pub fn with_context(context: Context) -> Option<Self> {
pub fn with_context(context: WgpuContext) -> Option<Self> {
let vello_renderer = Renderer::new(
&context.device,
RendererOptions {

View file

@ -1,4 +1,4 @@
use crate::Context;
use crate::WgpuContext;
use crate::shader_runtime::per_pixel_adjust_runtime::PerPixelAdjustShaderRuntime;
pub mod per_pixel_adjust_runtime;
@ -6,12 +6,12 @@ pub mod per_pixel_adjust_runtime;
pub const FULLSCREEN_VERTEX_SHADER_NAME: &str = "fullscreen_vertexfullscreen_vertex";
pub struct ShaderRuntime {
context: Context,
context: WgpuContext,
per_pixel_adjust: PerPixelAdjustShaderRuntime,
}
impl ShaderRuntime {
pub fn new(context: &Context) -> Self {
pub fn new(context: &WgpuContext) -> Self {
Self {
context: context.clone(),
per_pixel_adjust: PerPixelAdjustShaderRuntime::new(),

View file

@ -1,4 +1,4 @@
use crate::Context;
use crate::WgpuContext;
use crate::shader_runtime::{FULLSCREEN_VERTEX_SHADER_NAME, ShaderRuntime};
use futures::lock::Mutex;
use graphene_core::raster_types::{GPU, Raster};
@ -31,7 +31,7 @@ impl ShaderRuntime {
let mut cache = self.per_pixel_adjust.pipeline_cache.lock().await;
let pipeline = cache
.entry(shaders.fragment_shader_name.to_owned())
.or_insert_with(|| PerPixelAdjustGraphicsPipeline::new(&self.context, &shaders));
.or_insert_with(|| PerPixelAdjustGraphicsPipeline::new(&self.context, shaders));
let arg_buffer = args.map(|args| {
let device = &self.context.device;
@ -58,7 +58,7 @@ pub struct PerPixelAdjustGraphicsPipeline {
}
impl PerPixelAdjustGraphicsPipeline {
pub fn new(context: &Context, info: &Shaders) -> Self {
pub fn new(context: &WgpuContext, info: &Shaders) -> Self {
let device = &context.device;
let name = info.fragment_shader_name.to_owned();
@ -67,7 +67,7 @@ impl PerPixelAdjustGraphicsPipeline {
// TODO workaround to naga removing `:`
let fragment_name = fragment_name.replace(":", "");
let shader_module = device.create_shader_module(ShaderModuleDescriptor {
label: Some(&format!("PerPixelAdjust {} wgsl shader", name)),
label: Some(&format!("PerPixelAdjust {name} wgsl shader")),
source: ShaderSource::Wgsl(Cow::Borrowed(info.wgsl_shader)),
});
@ -107,16 +107,16 @@ impl PerPixelAdjustGraphicsPipeline {
}]
};
let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
label: Some(&format!("PerPixelAdjust {} PipelineLayout", name)),
label: Some(&format!("PerPixelAdjust {name} PipelineLayout")),
bind_group_layouts: &[&device.create_bind_group_layout(&BindGroupLayoutDescriptor {
label: Some(&format!("PerPixelAdjust {} BindGroupLayout 0", name)),
label: Some(&format!("PerPixelAdjust {name} BindGroupLayout 0")),
entries,
})],
push_constant_ranges: &[],
});
let pipeline = device.create_render_pipeline(&RenderPipelineDescriptor {
label: Some(&format!("PerPixelAdjust {} Pipeline", name)),
label: Some(&format!("PerPixelAdjust {name} Pipeline")),
layout: Some(&pipeline_layout),
vertex: VertexState {
module: &shader_module,
@ -155,7 +155,7 @@ impl PerPixelAdjustGraphicsPipeline {
}
}
pub fn dispatch(&self, context: &Context, textures: Table<Raster<GPU>>, arg_buffer: Option<Buffer>) -> Table<Raster<GPU>> {
pub fn dispatch(&self, context: &WgpuContext, textures: Table<Raster<GPU>>, arg_buffer: Option<Buffer>) -> Table<Raster<GPU>> {
assert_eq!(self.has_uniform, arg_buffer.is_some());
let device = &context.device;
let name = self.name.as_str();