diff --git a/Cargo.lock b/Cargo.lock index f9e51c47a..0acc5fcc5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5916,6 +5916,7 @@ dependencies = [ "uv-pep440", "uv-platform-tags", "uv-static", + "walkdir", ] [[package]] diff --git a/crates/uv-torch/Cargo.toml b/crates/uv-torch/Cargo.toml index d173c6ede..fdaa4653e 100644 --- a/crates/uv-torch/Cargo.toml +++ b/crates/uv-torch/Cargo.toml @@ -24,6 +24,7 @@ serde = { workspace = true } thiserror = { workspace = true } tracing = { workspace = true } url = { workspace = true } +walkdir = { workspace = true } [lints] workspace = true diff --git a/crates/uv-torch/src/accelerator.rs b/crates/uv-torch/src/accelerator.rs index 3165bd4c5..0bcdca246 100644 --- a/crates/uv-torch/src/accelerator.rs +++ b/crates/uv-torch/src/accelerator.rs @@ -1,6 +1,8 @@ +use std::path::Path; use std::str::FromStr; use tracing::debug; +use walkdir::WalkDir; use uv_pep440::Version; use uv_static::EnvVars; @@ -13,6 +15,10 @@ pub enum AcceleratorError { Version(#[from] uv_pep440::VersionParseError), #[error(transparent)] Utf8(#[from] std::string::FromUtf8Error), + #[error(transparent)] + ParseInt(#[from] std::num::ParseIntError), + #[error(transparent)] + WalkDir(#[from] walkdir::Error), #[error("Unknown AMD GPU architecture: {0}")] UnknownAmdGpuArchitecture(String), } @@ -30,6 +36,10 @@ pub enum Accelerator { Amd { gpu_architecture: AmdGpuArchitecture, }, + /// The Intel GPU (XPU). + /// + /// Currently, Intel GPUs do not depend on a driver/toolkit version at this level. + Xpu, } impl std::fmt::Display for Accelerator { @@ -37,21 +47,28 @@ impl std::fmt::Display for Accelerator { match self { Self::Cuda { driver_version } => write!(f, "CUDA {driver_version}"), Self::Amd { gpu_architecture } => write!(f, "AMD {gpu_architecture}"), + Self::Xpu => write!(f, "Intel GPU (XPU) detected"), } } } impl Accelerator { - /// Detect the CUDA driver version from the system. + /// Detect the GPU driver/architecture version from the system. /// /// Query, in order: /// 1. The `UV_CUDA_DRIVER_VERSION` environment variable. /// 2. The `UV_AMD_GPU_ARCHITECTURE` environment variable. - /// 2. `/sys/module/nvidia/version`, which contains the driver version (e.g., `550.144.03`). - /// 3. `/proc/driver/nvidia/version`, which contains the driver version among other information. - /// 4. `nvidia-smi --query-gpu=driver_version --format=csv,noheader`. - /// 5. `rocm_agent_enumerator`, which lists the AMD GPU architectures. + /// 3. `/sys/module/nvidia/version`, which contains the driver version (e.g., `550.144.03`). + /// 4. `/proc/driver/nvidia/version`, which contains the driver version among other information. + /// 5. `nvidia-smi --query-gpu=driver_version --format=csv,noheader`. + /// 6. `rocm_agent_enumerator`, which lists the AMD GPU architectures. + /// 7. `/sys/bus/pci/devices`, filtering for the Intel GPU via PCI. pub fn detect() -> Result, AcceleratorError> { + // Constants used for PCI device detection + const PCI_BASE_CLASS_MASK: u32 = 0x00ff_0000; + const PCI_BASE_CLASS_DISPLAY: u32 = 0x0003_0000; + const PCI_VENDOR_ID_INTEL: u32 = 0x8086; + // Read from `UV_CUDA_DRIVER_VERSION`. if let Ok(driver_version) = std::env::var(EnvVars::UV_CUDA_DRIVER_VERSION) { let driver_version = Version::from_str(&driver_version)?; @@ -150,6 +167,39 @@ impl Accelerator { } } + // Read from `/sys/bus/pci/devices` to filter for Intel GPU via PCI. + match WalkDir::new("/sys/bus/pci/devices") + .min_depth(1) + .max_depth(1) + .into_iter() + .collect::, _>>() + { + Ok(entries) => { + for entry in entries { + match parse_pci_device_ids(entry.path()) { + Ok((class, vendor)) => { + if (class & PCI_BASE_CLASS_MASK) == PCI_BASE_CLASS_DISPLAY + && vendor == PCI_VENDOR_ID_INTEL + { + debug!("Detected Intel GPU from PCI: vendor=0x{:04x}", vendor); + return Ok(Some(Self::Xpu)); + } + } + Err(e) => { + debug!("Failed to parse PCI device IDs: {e}"); + } + } + } + } + Err(e) + if e.io_error() + .is_some_and(|io| io.kind() == std::io::ErrorKind::NotFound) => {} + Err(e) => { + debug!("Failed to read PCI device directory with WalkDir: {e}"); + return Err(e.into()); + } + } + debug!("Failed to detect GPU driver version"); Ok(None) @@ -180,6 +230,22 @@ fn parse_proc_driver_nvidia_version(content: &str) -> Result, Ac Ok(Some(driver_version)) } +/// Reads and parses the PCI class and vendor ID from a given device path under `/sys/bus/pci/devices`. +fn parse_pci_device_ids(device_path: &Path) -> Result<(u32, u32), AcceleratorError> { + // Parse, e.g.: + // ```text + // - `class`: a hexadecimal string such as `0x030000` + // - `vendor`: a hexadecimal string such as `0x8086` + // ``` + let class_content = fs_err::read_to_string(device_path.join("class"))?; + let pci_class = u32::from_str_radix(class_content.trim().trim_start_matches("0x"), 16)?; + + let vendor_content = fs_err::read_to_string(device_path.join("vendor"))?; + let pci_vendor = u32::from_str_radix(vendor_content.trim().trim_start_matches("0x"), 16)?; + + Ok((pci_class, pci_vendor)) +} + /// A GPU architecture for AMD GPUs. /// /// See: diff --git a/crates/uv-torch/src/backend.rs b/crates/uv-torch/src/backend.rs index 0f2b72077..5ad71b385 100644 --- a/crates/uv-torch/src/backend.rs +++ b/crates/uv-torch/src/backend.rs @@ -185,6 +185,8 @@ pub enum TorchStrategy { os: Os, gpu_architecture: AmdGpuArchitecture, }, + /// Select the appropriate PyTorch index based on the operating system and Intel GPU presence. + Xpu { os: Os }, /// Use the specified PyTorch index. Backend(TorchBackend), } @@ -202,6 +204,7 @@ impl TorchStrategy { os: os.clone(), gpu_architecture, }), + Some(Accelerator::Xpu) => Ok(Self::Xpu { os: os.clone() }), None => Ok(Self::Backend(TorchBackend::Cpu)), }, TorchMode::Cpu => Ok(Self::Backend(TorchBackend::Cpu)), @@ -347,9 +350,27 @@ impl TorchStrategy { Either::Right(Either::Left(std::iter::once(TorchBackend::Cpu.index_url()))) } }, - TorchStrategy::Backend(backend) => { - Either::Right(Either::Right(std::iter::once(backend.index_url()))) - } + TorchStrategy::Xpu { os } => match os { + Os::Manylinux { .. } => Either::Right(Either::Right(Either::Left( + std::iter::once(TorchBackend::Xpu.index_url()), + ))), + Os::Windows + | Os::Musllinux { .. } + | Os::Macos { .. } + | Os::FreeBsd { .. } + | Os::NetBsd { .. } + | Os::OpenBsd { .. } + | Os::Dragonfly { .. } + | Os::Illumos { .. } + | Os::Haiku { .. } + | Os::Android { .. } + | Os::Pyodide { .. } => { + Either::Right(Either::Left(std::iter::once(TorchBackend::Cpu.index_url()))) + } + }, + TorchStrategy::Backend(backend) => Either::Right(Either::Right(Either::Right( + std::iter::once(backend.index_url()), + ))), } } } diff --git a/docs/guides/integration/pytorch.md b/docs/guides/integration/pytorch.md index a90ebeb6b..f060cbb6b 100644 --- a/docs/guides/integration/pytorch.md +++ b/docs/guides/integration/pytorch.md @@ -444,10 +444,10 @@ $ # With an environment variable. $ UV_TORCH_BACKEND=auto uv pip install torch ``` -When enabled, uv will query for the installed CUDA driver and AMD GPU versions then use the -most-compatible PyTorch index for all relevant packages (e.g., `torch`, `torchvision`, etc.). If no -such GPU is found, uv will fall back to the CPU-only index. uv will continue to respect existing -index configuration for any packages outside the PyTorch ecosystem. +When enabled, uv will query for the installed CUDA driver, AMD GPU versions and Intel GPU presence +then use the most-compatible PyTorch index for all relevant packages (e.g., `torch`, `torchvision`, +etc.). If no such GPU is found, uv will fall back to the CPU-only index. uv will continue to respect +existing index configuration for any packages outside the PyTorch ecosystem. You can also select a specific backend (e.g., CUDA 12.6) with `--torch-backend=cu126` (or `UV_TORCH_BACKEND=cu126`): @@ -460,4 +460,12 @@ $ # With an environment variable. $ UV_TORCH_BACKEND=cu126 uv pip install torch torchvision ``` +On Windows, Intel GPU (XPU) is not automatically selected with `--torch-backend=auto`, but you can +manually specify it using `--torch-backend=xpu`: + +```shell +$ # Manual selection for Intel GPU. +$ uv pip install torch torchvision --torch-backend=xpu +``` + At present, `--torch-backend` is only available in the `uv pip` interface.