SnakeBridge.WheelSelector (SnakeBridge v0.11.0)

Selects the appropriate wheel variant for Python packages based on hardware.

PyTorch and related packages (torchvision, torchaudio) have different wheel variants for different hardware configurations:

  • cpu - CPU-only build
  • cu118 - CUDA 11.8
  • cu121 - CUDA 12.1
  • cu124 - CUDA 12.4
  • rocm5.7 - AMD ROCm 5.7

This module detects the current hardware and selects the appropriate variant.

Examples

# Get the PyTorch variant for current hardware
variant = SnakeBridge.WheelSelector.pytorch_variant()
#=> "cu121" or "cpu"

# Get the index URL for pip
url = SnakeBridge.WheelSelector.pytorch_index_url()
#=> "https://download.pytorch.org/whl/cu121"

# Generate pip install command
cmd = SnakeBridge.WheelSelector.pip_install_command("torch", "2.1.0")
#=> "pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cu121"

Summary

Functions

  • available_variants() - Returns all available PyTorch wheel variants.
  • best_cuda_variant(cuda_version) - Returns the best matching CUDA variant for a given CUDA version.
  • normalize_cuda_version(version) - Normalizes a CUDA version string for wheel naming.
  • pip_install_command(package, version) - Generates a pip install command for a package.
  • pytorch_index_url() - Returns the PyTorch index URL for pip based on current hardware.
  • pytorch_package?(package) - Checks if a package is a PyTorch package that needs hardware-specific wheels.
  • pytorch_variant() - Returns the PyTorch wheel variant for the current hardware.
  • select_wheel(package, version) - Selects the appropriate wheel for a package based on current hardware.

Types

wheel_info()

@type wheel_info() :: %{
  package: String.t(),
  version: String.t(),
  variant: String.t() | nil,
  index_url: String.t() | nil
}

Functions

available_variants()

@spec available_variants() :: [String.t()]

Returns all available PyTorch wheel variants.

Useful for generating lock files that support multiple hardware configurations.

available_variants(package)

@spec available_variants(String.t()) :: [String.t()]
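
As an illustration of the lock-file use case, the sketch below builds one entry per variant. The LockEntriesSketch module and the entry shape are hypothetical and are not part of SnakeBridge. The variant list is passed in by hand, and the index URL pattern is assumed from the examples on this page.

```elixir
defmodule LockEntriesSketch do
  # Base URL assumed from the pytorch_index_url/0 example output on this page.
  @index_base "https://download.pytorch.org/whl"

  # Builds a map of variant => entry. The entry shape is an assumption
  # made for illustration, not SnakeBridge's lock file format.
  def entries(package, version, variants) do
    Map.new(variants, fn variant ->
      {variant,
       %{package: package, version: version, index_url: "#{@index_base}/#{variant}"}}
    end)
  end
end

entries = LockEntriesSketch.entries("torch", "2.1.0", ["cpu", "cu121"])
IO.inspect(Map.keys(entries))
```

In real use, the list returned by available_variants() would presumably take the place of the hand-written variant list.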

best_cuda_variant(cuda_version)

@spec best_cuda_variant(String.t() | nil) :: String.t()

Returns the best matching CUDA variant for a given CUDA version.

If there is no exact match, falls back to the closest available version (for example, "12.3" resolves to "cu124").

Examples

SnakeBridge.WheelSelector.best_cuda_variant("12.3")
#=> "cu124"

SnakeBridge.WheelSelector.best_cuda_variant("12.1")
#=> "cu121"
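
One way to get "closest available version" behavior is nearest-neighbor matching over the supported builds. The sketch below is an assumption about how such matching could work, not the library's implementation: ClosestCudaSketch is a hypothetical module, the supported list is copied from the variants in the module overview, and nil or major-only input is not handled.

```elixir
defmodule ClosestCudaSketch do
  # Assumed set of CUDA builds, taken from the variant list in the overview.
  @available [{11, 8}, {12, 1}, {12, 4}]

  # Expects a "major.minor" string; anything after the minor part is ignored.
  def closest(version) when is_binary(version) do
    [major, minor | _rest] =
      version |> String.split(".") |> Enum.map(&String.to_integer/1)

    target = score(major, minor)

    # Pick the build with the smallest distance to the requested version.
    {best_major, best_minor} =
      Enum.min_by(@available, fn {mj, mn} -> abs(score(mj, mn) - target) end)

    "cu#{best_major}#{best_minor}"
  end

  # Minor versions in @available are single digits, so major * 10 + minor
  # preserves their ordering.
  defp score(major, minor), do: major * 10 + minor
end

IO.puts(ClosestCudaSketch.closest("12.3"))
```

With this rule, "12.3" lands on "cu124" because 12.4 is one minor step away while 12.1 is two. The actual function may use a different tie-breaking or rounding rule.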

normalize_cuda_version(version)

@spec normalize_cuda_version(String.t() | nil) :: String.t() | nil

Normalizes a CUDA version string for wheel naming by removing the dot separator, so "12.1" becomes "121".

Examples

SnakeBridge.WheelSelector.normalize_cuda_version("12.1")
#=> "121"

SnakeBridge.WheelSelector.normalize_cuda_version("11.8")
#=> "118"
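
A minimal sketch of this kind of normalization, written as a standalone hypothetical module (NormalizeSketch) rather than the library's code. Passing nil through matches the spec above. Dropping a patch component and returning a bare major version unchanged are assumptions that this page does not confirm.

```elixir
defmodule NormalizeSketch do
  # nil in, nil out, as the @spec above allows.
  def normalize(nil), do: nil

  def normalize(version) when is_binary(version) do
    case String.split(version, ".") do
      # Keep major and minor, drop any patch component (assumption).
      [major, minor | _patch] -> major <> minor
      # No dot present: return the input unchanged (assumption).
      [major] -> major
    end
  end
end

IO.puts(NormalizeSketch.normalize("12.1"))
IO.puts(NormalizeSketch.normalize("11.8"))
```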

pip_install_command(package, version)

@spec pip_install_command(String.t(), String.t()) :: String.t()

Generates a pip install command for a package.

For PyTorch packages (torch, torchvision, torchaudio), includes the appropriate --index-url for hardware-specific wheels.

Examples

SnakeBridge.WheelSelector.pip_install_command("torch", "2.1.0")
#=> "pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cu121"

SnakeBridge.WheelSelector.pip_install_command("numpy", "1.26.4")
#=> "pip install numpy==1.26.4"
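
The branching described above can be sketched as follows. PipCommandSketch is a hypothetical module, and unlike the real function it takes the variant as an explicit argument instead of detecting hardware. The package list and URL pattern are taken from the text and examples on this page.

```elixir
defmodule PipCommandSketch do
  # Packages named in the description above as needing hardware-specific wheels.
  @pytorch_packages ["torch", "torchvision", "torchaudio"]
  @index_base "https://download.pytorch.org/whl"

  # `variant` is supplied by the caller here; this sketch does no detection.
  def command(package, version, variant) do
    base = "pip install #{package}==#{version}"

    if package in @pytorch_packages do
      base <> " --index-url #{@index_base}/#{variant}"
    else
      # Other packages install from the default index.
      base
    end
  end
end

IO.puts(PipCommandSketch.command("torch", "2.1.0", "cu121"))
IO.puts(PipCommandSketch.command("numpy", "1.26.4", "cu121"))
```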

pytorch_index_url()

@spec pytorch_index_url() :: String.t()

Returns the PyTorch index URL for pip based on current hardware.

Examples

SnakeBridge.WheelSelector.pytorch_index_url()
#=> "https://download.pytorch.org/whl/cu121"

pytorch_package?(package)

@spec pytorch_package?(String.t()) :: boolean()

Checks if a package is a PyTorch package (torch, torchvision, or torchaudio) that needs hardware-specific wheels.

pytorch_variant()

@spec pytorch_variant() :: String.t()

Returns the PyTorch wheel variant for the current hardware.

Examples

SnakeBridge.WheelSelector.pytorch_variant()
#=> "cu121"  # On CUDA 12.1 system
#=> "cpu"    # On CPU-only system
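
Callers that need to branch on the returned string can pattern match on its prefix. The sketch below is a hypothetical helper (VariantKindSketch), not part of SnakeBridge, and it only covers the variant names listed in the module overview.

```elixir
defmodule VariantKindSketch do
  # Classifies the variant strings from the module overview by prefix.
  # Any other string raises a FunctionClauseError.
  def kind("cpu"), do: :cpu
  def kind("cu" <> digits), do: {:cuda, digits}
  def kind("rocm" <> version), do: {:rocm, version}
end

IO.inspect(VariantKindSketch.kind("cu121"))
IO.inspect(VariantKindSketch.kind("rocm5.7"))
```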

select_wheel(package, version)

@spec select_wheel(String.t(), String.t()) :: wheel_info()

Selects the appropriate wheel for a package based on current hardware.

Returns a wheel_info() map. For PyTorch packages, variant and index_url are set; for other packages, both are nil.

Examples

SnakeBridge.WheelSelector.select_wheel("torch", "2.1.0")
#=> %{package: "torch", version: "2.1.0", variant: "cu121", index_url: "..."}

SnakeBridge.WheelSelector.select_wheel("numpy", "1.26.4")
#=> %{package: "numpy", version: "1.26.4", variant: nil, index_url: nil}
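
A map of this shape can be turned into an argument list for pip by matching on index_url. The sketch below is a hypothetical consumer (WheelArgsSketch), not part of SnakeBridge, and it works on hand-written maps that follow the wheel_info() type.

```elixir
defmodule WheelArgsSketch do
  # No index URL: install from the default package index.
  def pip_args(%{package: package, version: version, index_url: nil}) do
    ["install", "#{package}==#{version}"]
  end

  # Index URL present: point pip at the hardware-specific index.
  def pip_args(%{package: package, version: version, index_url: url})
      when is_binary(url) do
    ["install", "#{package}==#{version}", "--index-url", url]
  end
end

wheel = %{
  package: "torch",
  version: "2.1.0",
  variant: "cu121",
  index_url: "https://download.pytorch.org/whl/cu121"
}

IO.inspect(WheelArgsSketch.pip_args(wheel))
```

A list like this could be passed to System.cmd("pip", args), which avoids shell quoting issues that come with a single command string.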