Selects the appropriate wheel variant for Python packages based on hardware.
PyTorch and related packages (torchvision, torchaudio) have different wheel variants for different hardware configurations:
cpu - CPU-only build
cu118 - CUDA 11.8
cu121 - CUDA 12.1
cu124 - CUDA 12.4
rocm5.7 - AMD ROCm 5.7
This module detects the current hardware and selects the appropriate variant.
Examples
# Get the PyTorch variant for current hardware
variant = SnakeBridge.WheelSelector.pytorch_variant()
#=> "cu121" or "cpu"
# Get the index URL for pip
url = SnakeBridge.WheelSelector.pytorch_index_url()
#=> "https://download.pytorch.org/whl/cu121"
# Generate pip install command
cmd = SnakeBridge.WheelSelector.pip_install_command("torch", "2.1.0")
#=> "pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cu121"
Types
Functions
@spec available_variants() :: [String.t()]
Returns all available PyTorch variants for the given CUDA versions.
Useful for generating lock files that support multiple hardware configurations.
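As an illustration of the lock-file use case, the following is a minimal sketch that expands one package into an entry per variant. The module name `LockEntriesSketch`, the `entries/3` function, and the shape of the resulting map are all hypothetical; only the index URL base is taken from the examples on this page.

```elixir
defmodule LockEntriesSketch do
  # Base URL as it appears in the examples on this page.
  @index_base "https://download.pytorch.org/whl"

  # Hypothetical helper: builds one lock entry per variant so a single
  # lock file can serve several hardware configurations.
  def entries(package, version, variants) do
    Map.new(variants, fn variant ->
      {variant,
       %{package: package, version: version, index_url: "#{@index_base}/#{variant}"}}
    end)
  end
end

# The variant list would normally come from available_variants/0;
# it is written out here so the sketch runs on its own.
LockEntriesSketch.entries("torch", "2.1.0", ["cpu", "cu121"])
```

Keying the entries by variant lets an installer look up the entry that matches whatever hardware it detects at install time.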
Returns the best matching CUDA variant for a given CUDA version.
Falls back to the closest available version.
Examples
SnakeBridge.WheelSelector.best_cuda_variant("12.3")
#=> "cu124"
SnakeBridge.WheelSelector.best_cuda_variant("12.1")
#=> "cu121"
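One way to picture the fallback is as a nearest-version search. The sketch below is not SnakeBridge's implementation: the module `ClosestCudaSketch`, the scoring scheme, and the list of known builds (copied from the variant list on this page) are assumptions made for illustration.

```elixir
defmodule ClosestCudaSketch do
  # Assumed set of CUDA builds, taken from the variant list on this page.
  @known [{11, 8}, {12, 1}, {12, 4}]

  # Picks the known build whose version is numerically closest to the input.
  # On a tie, the earlier entry in @known wins.
  def best_variant(version) when is_binary(version) do
    target = version |> parse() |> score()
    {major, minor} = Enum.min_by(@known, fn v -> abs(score(v) - target) end)
    "cu#{major}#{minor}"
  end

  # Expects at least "major.minor"; any further components are ignored.
  defp parse(version) do
    [major, minor | _] =
      version |> String.split(".") |> Enum.map(&String.to_integer/1)

    {major, minor}
  end

  # Integer score so versions can be compared without floats.
  defp score({major, minor}), do: major * 100 + minor
end

ClosestCudaSketch.best_variant("12.3")
#=> "cu124"
```

With this scoring, "12.3" is one step from 12.4 and two steps from 12.1, which is why it resolves to "cu124" in the same way as the documented example.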
Normalizes a CUDA version string for wheel naming.
Examples
SnakeBridge.WheelSelector.normalize_cuda_version("12.1")
#=> "121"
SnakeBridge.WheelSelector.normalize_cuda_version("11.8")
#=> "118"
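The normalization shown above amounts to dropping the dot between the major and minor components. A minimal sketch, assuming a hypothetical `CudaVersionSketch.normalize/1` that also discards any patch component (that behavior is an assumption, not something this page documents):

```elixir
defmodule CudaVersionSketch do
  # Hypothetical helper, not SnakeBridge's implementation.
  # Keeps the major and minor components and joins them without a dot.
  def normalize(version) when is_binary(version) do
    version
    |> String.split(".")
    |> Enum.take(2)
    |> Enum.join()
  end
end

CudaVersionSketch.normalize("12.1")
#=> "121"
CudaVersionSketch.normalize("11.8")
#=> "118"
```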
Generates a pip install command for a package.
For PyTorch packages (torch, torchvision, torchaudio), includes the appropriate --index-url for hardware-specific wheels.
Examples
SnakeBridge.WheelSelector.pip_install_command("torch", "2.1.0")
#=> "pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cu121"
SnakeBridge.WheelSelector.pip_install_command("numpy", "1.26.4")
#=> "pip install numpy==1.26.4"
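The branching described above can be sketched as follows. `PipCommandSketch` and its `command/3` function are hypothetical; in particular, the variant is passed in explicitly here so the example runs anywhere, whereas the documented function takes only the package and version and relies on hardware detection.

```elixir
defmodule PipCommandSketch do
  # Package names and URL base as they appear on this page.
  @pytorch_packages ~w(torch torchvision torchaudio)
  @index_base "https://download.pytorch.org/whl"

  # Appends --index-url only for PyTorch packages; all other packages
  # get a plain pinned install from the default index.
  def command(package, version, variant) do
    base = "pip install #{package}==#{version}"

    if package in @pytorch_packages do
      "#{base} --index-url #{@index_base}/#{variant}"
    else
      base
    end
  end
end

PipCommandSketch.command("torch", "2.1.0", "cu121")
#=> "pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cu121"
PipCommandSketch.command("numpy", "1.26.4", "cu121")
#=> "pip install numpy==1.26.4"
```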
@spec pytorch_index_url() :: String.t()
Returns the PyTorch index URL for pip based on current hardware.
Examples
SnakeBridge.WheelSelector.pytorch_index_url()
#=> "https://download.pytorch.org/whl/cu121"
Checks if a package is a PyTorch package that needs hardware-specific wheels.
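A check of this kind could be as simple as a membership test against the three package names mentioned on this page. In the sketch below, the module name, the function name `pytorch_package?/1`, and the lowercase normalization are all assumptions; the actual function name and matching rules are not shown here.

```elixir
defmodule PytorchPackageSketch do
  # Package names listed on this page as needing hardware-specific wheels.
  @pytorch_packages ~w(torch torchvision torchaudio)

  # Hypothetical predicate: lowercases the name before comparing, on the
  # assumption that package names should match case-insensitively.
  def pytorch_package?(name) when is_binary(name) do
    String.downcase(name) in @pytorch_packages
  end
end

PytorchPackageSketch.pytorch_package?("torchvision")
#=> true
PytorchPackageSketch.pytorch_package?("numpy")
#=> false
```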
@spec pytorch_variant() :: String.t()
Returns the PyTorch wheel variant for the current hardware.
Examples
SnakeBridge.WheelSelector.pytorch_variant()
#=> "cu121" # On CUDA 12.1 system
#=> "cpu" # On CPU-only system
@spec select_wheel(String.t(), String.t()) :: wheel_info()
Selects the appropriate wheel for a package based on current hardware.
Returns wheel info including variant and index URL if applicable.
Examples
SnakeBridge.WheelSelector.select_wheel("torch", "2.1.0")
#=> %{package: "torch", version: "2.1.0", variant: "cu121", index_url: "..."}
SnakeBridge.WheelSelector.select_wheel("numpy", "1.26.4")
#=> %{package: "numpy", version: "1.26.4", variant: nil, index_url: nil}
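Because the result is a plain map with a nil index_url for ordinary packages, callers can branch on it with pattern matching. The sketch below uses hand-written maps shaped like the documented return value so that it runs without SnakeBridge; `WheelInfoSketch.pip_args/1` is a hypothetical consumer, not part of the library.

```elixir
defmodule WheelInfoSketch do
  # No index URL: install from the default package index.
  def pip_args(%{package: package, version: version, index_url: nil}) do
    ["install", "#{package}==#{version}"]
  end

  # Index URL present: point pip at the hardware-specific index.
  def pip_args(%{package: package, version: version, index_url: url})
      when is_binary(url) do
    ["install", "#{package}==#{version}", "--index-url", url]
  end
end

# Maps written by hand to mirror the shape returned by select_wheel/2.
torch = %{
  package: "torch",
  version: "2.1.0",
  variant: "cu121",
  index_url: "https://download.pytorch.org/whl/cu121"
}

numpy = %{package: "numpy", version: "1.26.4", variant: nil, index_url: nil}

WheelInfoSketch.pip_args(torch)
#=> ["install", "torch==2.1.0", "--index-url", "https://download.pytorch.org/whl/cu121"]
WheelInfoSketch.pip_args(numpy)
#=> ["install", "numpy==1.26.4"]
```

Returning an argument list instead of a single string makes it easier to hand the result to something like System.cmd/3 without shell quoting concerns.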