Add XPU to --torch-backend (#14172)

## Summary

Like ROCm, no auto-detection for now.
This commit is contained in:
Charlie Marsh 2025-06-20 20:33:20 -04:00 committed by GitHub
parent 0133bcc8ca
commit e59835d50c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 27 additions and 0 deletions

View file

@ -86,4 +86,10 @@ mod tests {
let url = DisplaySafeUrl::parse("https://download.pytorch.org/whl/cpu").unwrap();
assert_eq!(SystemDependency::from_index(&url), None);
}
#[test]
fn pytorch_xpu() {
let url = DisplaySafeUrl::parse("https://download.pytorch.org/whl/xpu").unwrap();
assert_eq!(SystemDependency::from_index(&url), None);
}
}

View file

@ -171,6 +171,8 @@ pub enum TorchMode {
#[serde(rename = "rocm4.0.1")]
#[cfg_attr(feature = "clap", clap(name = "rocm4.0.1"))]
Rocm401,
/// Use the PyTorch index for Intel XPU.
Xpu,
}
/// The strategy to use when determining the appropriate PyTorch index.
@ -237,6 +239,7 @@ impl TorchStrategy {
TorchMode::Rocm42 => Ok(Self::Backend(TorchBackend::Rocm42)),
TorchMode::Rocm41 => Ok(Self::Backend(TorchBackend::Rocm41)),
TorchMode::Rocm401 => Ok(Self::Backend(TorchBackend::Rocm401)),
TorchMode::Xpu => Ok(Self::Backend(TorchBackend::Xpu)),
}
}
@ -356,6 +359,7 @@ pub enum TorchBackend {
Rocm42,
Rocm41,
Rocm401,
Xpu,
}
impl TorchBackend {
@ -403,6 +407,7 @@ impl TorchBackend {
Self::Rocm42 => &ROCM42_INDEX_URL,
Self::Rocm41 => &ROCM41_INDEX_URL,
Self::Rocm401 => &ROCM401_INDEX_URL,
Self::Xpu => &XPU_INDEX_URL,
}
}
@ -465,6 +470,7 @@ impl TorchBackend {
TorchBackend::Rocm42 => None,
TorchBackend::Rocm41 => None,
TorchBackend::Rocm401 => None,
TorchBackend::Xpu => None,
}
}
@ -512,6 +518,7 @@ impl TorchBackend {
TorchBackend::Rocm42 => Some(Version::new([4, 2])),
TorchBackend::Rocm41 => Some(Version::new([4, 1])),
TorchBackend::Rocm401 => Some(Version::new([4, 0, 1])),
TorchBackend::Xpu => None,
}
}
}
@ -562,6 +569,7 @@ impl FromStr for TorchBackend {
"rocm4.2" => Ok(TorchBackend::Rocm42),
"rocm4.1" => Ok(TorchBackend::Rocm41),
"rocm4.0.1" => Ok(TorchBackend::Rocm401),
"xpu" => Ok(TorchBackend::Xpu),
_ => Err(format!("Unknown PyTorch backend: {s}")),
}
}
@ -725,3 +733,5 @@ static ROCM41_INDEX_URL: LazyLock<IndexUrl> =
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm4.1").unwrap());
static ROCM401_INDEX_URL: LazyLock<IndexUrl> =
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm4.0.1").unwrap());
static XPU_INDEX_URL: LazyLock<IndexUrl> =
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/xpu").unwrap());