SnakeBridge.WheelSelector (SnakeBridge v0.15.0)


Selects the appropriate wheel variant for Python packages based on hardware.

PyTorch and related packages (torchvision, torchaudio) publish separate wheel variants for different hardware configurations:

  • cpu - CPU-only build
  • cu118 - CUDA 11.8
  • cu121 - CUDA 12.1
  • cu124 - CUDA 12.4
  • rocm5.7 - AMD ROCm 5.7

This module detects the current hardware and selects the appropriate variant.
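The mapping from hardware to variant can be pictured as a small lookup table. The sketch below is hypothetical and is not SnakeBridge's implementation: the module name `WheelSketch.Variant` and the `{:cuda, version}` / `{:rocm, version}` / `:cpu` hardware shapes are assumptions, and the variant tags come from the list above. It does an exact match and, as an assumption, falls back to `"cpu"` for versions it does not list, whereas the real module picks the closest CUDA variant (see best_cuda_variant/1).

```elixir
defmodule WheelSketch.Variant do
  @moduledoc """
  Hypothetical mapping from a detected hardware description to a
  PyTorch wheel variant tag. Not SnakeBridge's implementation.
  """

  # Variant tags taken from the list of PyTorch wheel variants.
  @cuda_variants %{"11.8" => "cu118", "12.1" => "cu121", "12.4" => "cu124"}
  @rocm_variants %{"5.7" => "rocm5.7"}

  @doc "Maps {:cuda, version}, {:rocm, version} or :cpu to a variant tag."
  @spec for_hardware({:cuda, String.t()} | {:rocm, String.t()} | :cpu) :: String.t()
  # Assumption: unknown versions fall back to the CPU-only build.
  def for_hardware({:cuda, version}), do: Map.get(@cuda_variants, version, "cpu")
  def for_hardware({:rocm, version}), do: Map.get(@rocm_variants, version, "cpu")
  def for_hardware(:cpu), do: "cpu"
end
```

For example, `WheelSketch.Variant.for_hardware({:cuda, "12.1"})` returns `"cu121"`, and `WheelSketch.Variant.for_hardware({:rocm, "5.7"})` returns `"rocm5.7"`.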

Examples

# Get the PyTorch variant for current hardware
variant = SnakeBridge.WheelSelector.pytorch_variant()
#=> "cu121" or "cpu"

# Get the index URL for pip
url = SnakeBridge.WheelSelector.pytorch_index_url()
#=> "https://download.pytorch.org/whl/cu121"

# Generate pip install command
cmd = SnakeBridge.WheelSelector.pip_install_command("torch", "2.1.0")
#=> "pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cu121"

Summary

Functions

available_variants()

Returns all available PyTorch wheel variants for the supported CUDA versions.

best_cuda_variant(cuda_version)

Returns the best matching CUDA variant for a given CUDA version.

normalize_cuda_version(version)

Normalizes a CUDA version string for wheel naming.

pip_install_command(package, version)

Generates a pip install command for a package.

pytorch_index_url()

Returns the PyTorch index URL for pip based on current hardware.

pytorch_package?(package)

Checks if a package is a PyTorch package that needs hardware-specific wheels.

pytorch_variant()

Returns the PyTorch wheel variant for the current hardware.

select_wheel(package, version)

Selects the appropriate wheel for a package based on current hardware.

Types

wheel_info()

@type wheel_info() :: %{
  package: String.t(),
  version: String.t(),
  variant: String.t() | nil,
  index_url: String.t() | nil
}

Functions

available_variants()

@spec available_variants() :: [String.t()]

Returns all available PyTorch wheel variants for the supported CUDA versions.

Useful for generating lock files that support multiple hardware configurations.
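As an illustration of the lock-file use case, one pinned package can be expanded into one entry per variant. This is a hypothetical sketch: `WheelSketch.Lock` and its entry shape (borrowed from the `wheel_info()` type) are assumptions, and the index URL pattern is inferred from the examples on this page.

```elixir
defmodule WheelSketch.Lock do
  # Hypothetical: expands one pinned package into one lock entry per variant.
  # The URL pattern is inferred from the pytorch_index_url/0 examples.
  @base_url "https://download.pytorch.org/whl/"

  @spec entries(String.t(), String.t(), [String.t()]) :: [map()]
  def entries(package, version, variants) do
    for variant <- variants do
      %{package: package, version: version, variant: variant, index_url: @base_url <> variant}
    end
  end
end
```

In a real project the variant list would come from `SnakeBridge.WheelSelector.available_variants()` rather than being written out by hand, so the lock file covers every hardware configuration the selector knows about.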

available_variants(package)

@spec available_variants(String.t()) :: [String.t()]

best_cuda_variant(cuda_version)

@spec best_cuda_variant(String.t() | nil) :: String.t()

Returns the best matching CUDA variant for a given CUDA version.

If there is no exact match, falls back to the closest available CUDA variant.

Examples

SnakeBridge.WheelSelector.best_cuda_variant("12.3")
#=> "cu124"

SnakeBridge.WheelSelector.best_cuda_variant("12.1")
#=> "cu121"
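A closest-match rule consistent with the two examples above can be sketched by turning each version into an integer and picking the supported version with the smallest distance. This is a hypothetical sketch, not SnakeBridge's code: the module name, the supported list (taken from the variants named at the top of this page), and the `nil` to `"cpu"` fallback are assumptions.

```elixir
defmodule WheelSketch.BestCuda do
  # Hypothetical closest-match sketch; the supported list is assumed
  # from the CUDA variants named at the top of this page.
  @supported ["11.8", "12.1", "12.4"]

  # Assumption: a nil version (no CUDA detected) maps to the CPU build.
  @spec best(String.t() | nil) :: String.t()
  def best(nil), do: "cpu"

  def best(version) do
    target = to_number(version)

    # Pick the supported version numerically closest to the target.
    closest =
      Enum.min_by(@supported, fn candidate -> abs(to_number(candidate) - target) end)

    "cu" <> String.replace(closest, ".", "")
  end

  # "12.3" -> 1203, so minor versions compare correctly within a major.
  # A bare major such as "12" is padded with a "0" minor.
  defp to_number(version) do
    [major, minor | _] = String.split(version, ".") ++ ["0"]
    String.to_integer(major) * 100 + String.to_integer(minor)
  end
end
```

With this rule, `"12.3"` is one step from 12.4 and two from 12.1, so it maps to `"cu124"`, matching the documented example.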

normalize_cuda_version(version)

@spec normalize_cuda_version(String.t() | nil) :: String.t() | nil

Normalizes a CUDA version string for wheel naming.

Examples

SnakeBridge.WheelSelector.normalize_cuda_version("12.1")
#=> "121"

SnakeBridge.WheelSelector.normalize_cuda_version("11.8")
#=> "118"
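The normalization shown in these examples amounts to dropping the dot between major and minor. A minimal sketch, assuming (this is not stated above) that any patch component such as `"12.1.105"` is discarded; `WheelSketch.Normalize` is a hypothetical name.

```elixir
defmodule WheelSketch.Normalize do
  # Hypothetical sketch: keep major.minor and drop the dot.
  # Assumption: patch components such as "12.1.105" are discarded.
  @spec normalize(String.t() | nil) :: String.t() | nil
  def normalize(nil), do: nil

  def normalize(version) do
    version
    |> String.trim()
    |> String.split(".")
    |> Enum.take(2)
    |> Enum.join()
  end
end
```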

pip_install_command(package, version)

@spec pip_install_command(String.t(), String.t()) :: String.t()

Generates a pip install command for a package.

For PyTorch packages (torch, torchvision, torchaudio), includes the appropriate --index-url for hardware-specific wheels.

Examples

SnakeBridge.WheelSelector.pip_install_command("torch", "2.1.0")
#=> "pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cu121"

SnakeBridge.WheelSelector.pip_install_command("numpy", "1.26.4")
#=> "pip install numpy==1.26.4"
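The command-building rule described above can be sketched as follows. To keep the output deterministic, this hypothetical version takes the variant as an explicit argument instead of detecting it; `WheelSketch.Pip.command/3` is an assumed name, and the package list and URL pattern come from this page.

```elixir
defmodule WheelSketch.Pip do
  # Hypothetical sketch; the variant is passed in explicitly instead of
  # being detected, so the output does not depend on the machine.
  @pytorch_packages ["torch", "torchvision", "torchaudio"]
  @base_url "https://download.pytorch.org/whl/"

  @spec command(String.t(), String.t(), String.t()) :: String.t()
  def command(package, version, variant) do
    base = "pip install #{package}==#{version}"

    # Only PyTorch packages get a hardware-specific index URL.
    if package in @pytorch_packages do
      base <> " --index-url " <> @base_url <> variant
    else
      base
    end
  end
end
```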

pytorch_index_url()

@spec pytorch_index_url() :: String.t()

Returns the PyTorch index URL for pip based on current hardware.

Examples

SnakeBridge.WheelSelector.pytorch_index_url()
#=> "https://download.pytorch.org/whl/cu121"
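Judging by the examples on this page, the index URL appears to be the variant tag appended to a fixed base. A hypothetical sketch of that pattern (`WheelSketch.IndexUrl` is an assumed name); the real function additionally detects the variant itself.

```elixir
defmodule WheelSketch.IndexUrl do
  # Hypothetical: the URL pattern is inferred from the examples on this page.
  @base_url "https://download.pytorch.org/whl/"

  @spec for_variant(String.t()) :: String.t()
  def for_variant(variant) when is_binary(variant), do: @base_url <> variant
end
```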

pytorch_package?(package)

@spec pytorch_package?(String.t()) :: boolean()

Checks if a package is a PyTorch package that needs hardware-specific wheels.
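A membership check over the three package names listed under pip_install_command/2 is enough to express this. A hypothetical sketch; the exact, case-sensitive comparison is an assumption.

```elixir
defmodule WheelSketch.Package do
  # Package names taken from pip_install_command/2's documentation.
  @pytorch_packages ~w(torch torchvision torchaudio)

  # Assumption: names are compared exactly (case-sensitive).
  @spec pytorch_package?(String.t()) :: boolean()
  def pytorch_package?(package) when is_binary(package), do: package in @pytorch_packages
end
```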

pytorch_variant()

@spec pytorch_variant() :: String.t()

Returns the PyTorch wheel variant for the current hardware.

Examples

SnakeBridge.WheelSelector.pytorch_variant()
#=> "cu121"  # On CUDA 12.1 system
#=> "cpu"    # On CPU-only system
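This page does not say how the current hardware is detected. One possible approach, shown here only as a hypothetical sketch, is to read the `CUDA Version: X.Y` field that `nvidia-smi` prints in its header (the highest CUDA version the installed driver supports). The module and function names are assumptions, and SnakeBridge may detect hardware differently.

```elixir
defmodule WheelSketch.Detect do
  # Hypothetical detection sketch: reads the "CUDA Version: X.Y" field that
  # nvidia-smi prints in its header. Returns nil when no NVIDIA tooling is found.

  @spec cuda_version() :: String.t() | nil
  def cuda_version do
    case System.find_executable("nvidia-smi") do
      nil ->
        nil

      path ->
        case System.cmd(path, [], stderr_to_stdout: true) do
          {output, 0} -> parse_cuda_version(output)
          _error -> nil
        end
    end
  end

  # Pure parser, separated so it can be tested without a GPU.
  @spec parse_cuda_version(String.t()) :: String.t() | nil
  def parse_cuda_version(output) do
    case Regex.run(~r/CUDA Version:\s*(\d+\.\d+)/, output) do
      [_, version] -> version
      nil -> nil
    end
  end
end
```

The detected version (or `nil`) could then be passed to `SnakeBridge.WheelSelector.best_cuda_variant/1`, whose spec accepts `nil`.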

select_wheel(package, version)

@spec select_wheel(String.t(), String.t()) :: wheel_info()

Selects the appropriate wheel for a package based on current hardware.

Returns a wheel_info() map. For packages that do not need hardware-specific wheels, variant and index_url are nil.

Examples

SnakeBridge.WheelSelector.select_wheel("torch", "2.1.0")
#=> %{package: "torch", version: "2.1.0", variant: "cu121", index_url: "..."}

SnakeBridge.WheelSelector.select_wheel("numpy", "1.26.4")
#=> %{package: "numpy", version: "1.26.4", variant: nil, index_url: nil}
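Assembling a `wheel_info()` map along the lines of these examples can be sketched with a guard on the package name. This hypothetical version takes the variant as an argument so the result is deterministic; the module name is assumed, while the map shape follows the `wheel_info()` type documented above.

```elixir
defmodule WheelSketch.Select do
  # Hypothetical sketch of assembling a wheel_info() map; the variant is
  # passed in rather than detected so the result is deterministic.
  @pytorch_packages ~w(torch torchvision torchaudio)
  @base_url "https://download.pytorch.org/whl/"

  @type wheel_info :: %{
          package: String.t(),
          version: String.t(),
          variant: String.t() | nil,
          index_url: String.t() | nil
        }

  @spec select(String.t(), String.t(), String.t()) :: wheel_info()
  def select(package, version, variant) when package in @pytorch_packages do
    %{package: package, version: version, variant: variant, index_url: @base_url <> variant}
  end

  # Non-PyTorch packages need no hardware-specific wheel.
  def select(package, version, _variant) do
    %{package: package, version: version, variant: nil, index_url: nil}
  end
end
```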