mirror of
https://github.com/astral-sh/uv.git
synced 2025-08-04 19:08:04 +00:00
Automatically infer the PyTorch index via --torch-backend=auto
(#12070)
## Summary This is a prototype that I'm considering shipping under `--preview`, based on [`light-the-torch`](https://github.com/pmeier/light-the-torch). `light-the-torch` patches pip to pull PyTorch packages from the PyTorch indexes automatically. And, in particular, `light-the-torch` will query the installed CUDA drivers to determine which indexes are compatible with your system. This PR implements equivalent behavior under `--torch-backend auto`, though you can also set `--torch-backend cpu`, etc. for convenience. When enabled, the registry client will fetch from the appropriate PyTorch index when it sees a package from the PyTorch ecosystem (and ignore any other configured indexes, _unless_ the package is explicitly pinned to a different index). Right now, this is only implemented in the `uv pip` CLI, since it doesn't quite fit into the lockfile APIs given that it relies on feature detection on the currently-running machine. ## Test Plan On macOS, you can test this with (e.g.): ```shell UV_TORCH_BACKEND=auto UV_CUDA_DRIVER_VERSION=450.80.2 cargo run \ pip install torch --python-platform linux --python-version 3.12 ``` On a GPU-enabled EC2 machine: ```shell ubuntu@ip-172-31-47-149:~/uv$ UV_TORCH_BACKEND=auto cargo run pip install torch -v Finished `dev` profile [unoptimized + debuginfo] target(s) in 0.31s Running `target/debug/uv pip install torch -v` DEBUG uv 0.6.6 (e95ca063b 2025-03-14) DEBUG Searching for default Python interpreter in virtual environments DEBUG Found `cpython-3.13.0-linux-x86_64-gnu` at `/home/ubuntu/uv/.venv/bin/python3` (virtual environment) DEBUG Using Python 3.13.0 environment at: .venv DEBUG Acquired lock for `.venv` DEBUG At least one requirement is not satisfied: torch warning: The `--torch-backend` setting is experimental and may change without warning. Pass `--preview` to disable this warning. DEBUG Detected CUDA driver version from `/sys/module/nvidia/version`: 550.144.3 ... ```
This commit is contained in:
parent
e40c551b80
commit
5173b59b50
31 changed files with 1289 additions and 29 deletions
|
@ -678,4 +678,11 @@ impl EnvVars {
|
|||
///
|
||||
/// This is a quasi-standard variable, described, e.g., in `ncurses(3x)`.
|
||||
pub const COLUMNS: &'static str = "COLUMNS";
|
||||
|
||||
/// The CUDA driver version to assume when inferring the PyTorch backend.
|
||||
#[attr_hidden]
|
||||
pub const UV_CUDA_DRIVER_VERSION: &'static str = "UV_CUDA_DRIVER_VERSION";
|
||||
|
||||
/// Equivalent to the `--torch-backend` command-line argument (e.g., `cpu`, `cu126`, or `auto`).
|
||||
pub const UV_TORCH_BACKEND: &'static str = "UV_TORCH_BACKEND";
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue