from __future__ import annotations

import glob
import hashlib
import json
import os
import platform
import random
import re
import shutil
import string
import subprocess
import sys
from collections.abc import Callable, Sequence
from functools import lru_cache
from typing import IO, TYPE_CHECKING, Any, Literal, Optional, Protocol, TypeAlias, TypedDict, cast, get_args

import yaml

import ramalama.amdkfd as amdkfd
from ramalama.logger import logger
from ramalama.version import version

if TYPE_CHECKING:
    from argparse import Namespace

    from ramalama.arg_types import SUPPORTED_ENGINES, ContainerArgType
    from ramalama.config import Config, RamalamaImageConfig
    from ramalama.transports.base import Transport

MNT_DIR = "/mnt/models"
MNT_FILE = f"{MNT_DIR}/model.file"
MNT_MMPROJ_FILE = f"{MNT_DIR}/mmproj.file"
MNT_FILE_DRAFT = f"{MNT_DIR}/draft_model.file"
MNT_CHAT_TEMPLATE_FILE = f"{MNT_DIR}/chat_template.file"

RAG_DIR = "/rag"
RAG_CONTENT = f"{MNT_DIR}/vector.db"

MIN_VRAM_BYTES = 1073741824  # 1GiB

SPLIT_MODEL_PATH_RE = r'(.*?)(?:/)?([^/]*)-00001-of-(\d{5})\.gguf'


def is_split_file_model(model_path):
    """Returns True if model_path names the first shard of a split model (-00001-of-%05d.gguf)."""
    return bool(re.match(SPLIT_MODEL_PATH_RE, model_path))
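
# Illustrative examples only (hypothetical file names); the pattern matches the
# first shard of a split GGUF download:
#   is_split_file_model("granite-00001-of-00003.gguf")         -> True
#   is_split_file_model("models/granite-00002-of-00003.gguf")  -> False (not the first shard)
#   is_split_file_model("granite.gguf")                        -> False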


def sanitize_filename(filename: str) -> str:
    return filename.replace(":", "-")


podman_machine_accel = False


def confirm_no_gpu(name, provider) -> bool:
    while True:
        user_input = (
            input(
                f"Warning! Your VM {name} is using {provider}, which does not support GPU. "
                "Only the provider libkrun has GPU support. "
                "See `man ramalama-macos` for more information. "
                "Do you want to proceed without GPU? (yes/no): "
            )
            .strip()
            .lower()
        )
        if user_input in ["yes", "y"]:
            return True
        if user_input in ["no", "n"]:
            return False
        print("Invalid input. Please enter 'yes' or 'no'.")


def handle_provider(machine, config: Config | None = None) -> bool | None:
    global podman_machine_accel
    name = machine.get("Name")
    provider = machine.get("VMType")
    running = machine.get("Running")
    if running:
        if provider == "applehv":
            if config is not None and config.user.no_missing_gpu_prompt:
                return True
            else:
                return confirm_no_gpu(name, provider)
        if "krun" in provider:
            podman_machine_accel = True
            return True

    return None


def apple_vm(engine: SUPPORTED_ENGINES, config: Config | None = None) -> bool:
    podman_machine_list = [engine, "machine", "list", "--format", "json", "--all-providers"]
    try:
        machines_json = run_cmd(podman_machine_list, ignore_stderr=True, encoding="utf-8").stdout.strip()
        machines = json.loads(machines_json)
        for machine in machines:
            result = handle_provider(machine, config)
            if result is not None:
                return result
    except (subprocess.CalledProcessError, json.JSONDecodeError) as e:
        logger.warning(f"Failed to list and parse podman machines: {e}")
    return False


def perror(*args, **kwargs):
    print(*args, file=sys.stderr, **kwargs)


def available(cmd: str) -> bool:
    return shutil.which(cmd) is not None


def quoted(arr) -> str:
    """Return string with quotes around elements containing spaces."""
    return " ".join(['"' + element + '"' if ' ' in element else element for element in arr])


def exec_cmd(args, stdout2null: bool = False, stderr2null: bool = False):
    logger.debug(f"exec_cmd: {quoted(args)}")

    stdout_target = subprocess.DEVNULL if stdout2null else None
    stderr_target = subprocess.DEVNULL if stderr2null else None
    try:
        result = subprocess.run(args, stdout=stdout_target, stderr=stderr_target, check=False)
        sys.exit(result.returncode)
    except Exception as e:
        perror(f"Failed to execute {quoted(args)}: {e}")
        raise


def run_cmd(
    args: Sequence[str],
    cwd: str | None = None,
    stdout: int | IO[Any] | None = subprocess.PIPE,
    ignore_stderr: bool = False,
    ignore_all: bool = False,
    encoding: str | None = None,
    env: dict[str, str] | None = None,
) -> subprocess.CompletedProcess[Any]:
    """
    Run the given command arguments.

    Args:
        args: command line arguments to execute in a subprocess
        cwd: optional working directory to run the command from
        stdout: standard output configuration
        ignore_stderr: if True, ignore standard error
        ignore_all: if True, ignore both standard output and standard error
        encoding: encoding to apply to the result text
        env: optional environment variables, merged over the current environment
    """
    logger.debug(f"run_cmd: {quoted(args)}")
    logger.debug(f"Working directory: {cwd}")
    logger.debug(f"Ignore stderr: {ignore_stderr}")
    logger.debug(f"Ignore all: {ignore_all}")
    logger.debug(f"env: {env}")

    serr = None
    if ignore_all or ignore_stderr:
        serr = subprocess.DEVNULL

    sout = stdout
    if ignore_all:
        sout = subprocess.DEVNULL

    if env:
        env = os.environ | env

    result = subprocess.run(args, check=True, cwd=cwd, stdout=sout, stderr=serr, encoding=encoding, env=env)
    logger.debug(f"Command finished with return code: {result.returncode}")

    return result
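
# Illustrative usage (assumes a container engine such as podman is installed):
#   result = run_cmd(["podman", "info", "--format", "json"], ignore_stderr=True, encoding="utf-8")
#   info = json.loads(result.stdout)
# With `encoding` set, result.stdout is str; without it, it is bytes. A non-zero
# exit raises subprocess.CalledProcessError (check=True).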


def populate_volume_from_image(
    model: Transport, args: Namespace, output_filename: str, src_model_dir: str = "models"
) -> str:
    """Populates a named volume with model assets extracted from an OCI image.

    Creates a content-addressed volume if needed, exports the rootfs of a
    temporary container created from the model image, extracts only
    <src_model_dir>/<output_filename> into the volume, and returns the volume
    name. This mirrors Podman image mounts for engines that cannot mount
    images directly.
    """

    vol_hash = hashlib.sha256(model.model.encode()).hexdigest()[:12]
    volume = f"ramalama-models-{vol_hash}"
    src = f"src-{vol_hash}"

    # Ensure volume exists
    run_cmd([args.engine, "volume", "create", volume], ignore_stderr=True)

    # Fresh source container to export from
    run_cmd([args.engine, "rm", "-f", src], ignore_stderr=True)
    run_cmd([args.engine, "create", "--name", src, model.model])

    try:
        # Stream whole rootfs -> extract only models/<basename>
        export_cmd = [args.engine, "export", src]
        untar_cmd = [
            args.engine,
            "run",
            "--rm",
            "-i",
            "--mount",
            f"type=volume,src={volume},dst=/mnt",
            "busybox",
            "tar",
            "-C",
            "/mnt",
            "--strip-components=1",
            "-x",
            "-p",
            "-f",
            "-",
            f"{src_model_dir}/{output_filename}",  # archive member path inside the exported rootfs
        ]

        with (
            subprocess.Popen(export_cmd, stdout=subprocess.PIPE) as p_out,
            subprocess.Popen(untar_cmd, stdin=p_out.stdout) as p_in,
        ):
            p_out.stdout.close()  # type: ignore
            rc_in = p_in.wait()
            rc_out = p_out.wait()
            if rc_in != 0 or rc_out != 0:
                raise subprocess.CalledProcessError(rc_in or rc_out, untar_cmd if rc_in else export_cmd)
    finally:
        run_cmd([args.engine, "rm", "-f", src], ignore_stderr=True)

    return volume
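
# The streaming above is equivalent to this illustrative shell pipeline
# (shown with podman, though args.engine may be any supported engine):
#   podman export src-<hash> \
#     | podman run --rm -i --mount type=volume,src=ramalama-models-<hash>,dst=/mnt \
#         busybox tar -C /mnt --strip-components=1 -xpf - models/<output_filename>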


def generate_sha256_binary(to_hash: bytes, with_sha_prefix: bool = True) -> str:
    """
    Generates a sha256 hash for data bytes.

    Args:
        to_hash (bytes): The data to generate the sha256 hash for.
        with_sha_prefix (bool): if True, prefix the digest with "sha256-".

    Returns:
        str: hex digest of the input, prefixed with "sha256-" when with_sha_prefix is True.
    """
    h = hashlib.new("sha256")
    h.update(to_hash)
    if with_sha_prefix:
        return f"sha256-{h.hexdigest()}"
    return h.hexdigest()


def generate_sha256(to_hash: str, with_sha_prefix: bool = True) -> str:
    return generate_sha256_binary(to_hash.encode("utf-8"), with_sha_prefix)
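
# Illustrative (the digest of b"hello" is a well-known test vector):
#   generate_sha256("hello")
#   -> "sha256-2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824"
#   generate_sha256("hello", with_sha_prefix=False) drops the "sha256-" prefix.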


def verify_checksum(filename: str) -> bool:
    """
    Verifies if the SHA-256 checksum of a file matches the checksum provided in
    the filename.

    Args:
        filename (str): The filename containing the checksum prefix
        (e.g., "sha256:<checksum>")

    Returns:
        bool: True if the checksum matches, False otherwise.
    """

    if not os.path.exists(filename):
        return False

    # Check if the filename starts with "sha256:" or "sha256-" and extract the checksum from filename
    expected_checksum = ""
    fn_base = os.path.basename(filename)
    if fn_base.startswith("sha256:"):
        expected_checksum = fn_base.split(":")[1]
    elif fn_base.startswith("sha256-"):
        expected_checksum = fn_base.split("-")[1]
    else:
        raise ValueError(f"filename has to start with 'sha256:' or 'sha256-': {fn_base}")

    if len(expected_checksum) != 64:
        raise ValueError("invalid checksum length in filename")

    # Calculate the SHA-256 checksum of the file contents
    sha256_hash = hashlib.sha256()
    with open(filename, "rb") as f:
        for byte_block in iter(lambda: f.read(4096), b""):
            sha256_hash.update(byte_block)

    # Compare the checksums
    return sha256_hash.hexdigest() == expected_checksum
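
# Illustrative: for a blob stored under its own digest (hypothetical path),
#   verify_checksum("/store/blobs/sha256-2cf24d...9824")
# re-hashes the file contents and compares them to the 64-hex-digit checksum
# embedded in the basename, raising ValueError for malformed names.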


def genname():
    return "ramalama-" + "".join(random.choices(string.ascii_letters + string.digits, k=10))


def engine_version(engine: SUPPORTED_ENGINES) -> str:
    # Query the container engine client for its version string
    cmd_args = [str(engine), "version", "--format", "{{ .Client.Version }}"]
    return run_cmd(cmd_args, encoding="utf-8").stdout.strip()


class CDI_DEVICE(TypedDict):
    name: str


class CDI_RETURN_TYPE(TypedDict):
    devices: list[CDI_DEVICE]


def load_cdi_config(spec_dirs: list[str]) -> CDI_RETURN_TYPE | None:
    """Loads the first YAML or JSON CDI configuration file found in the given directories."""

    for spec_dir in spec_dirs:
        for root, _, files in os.walk(spec_dir):
            for file in files:
                _, ext = os.path.splitext(file)
                file_path = os.path.join(root, file)
                if ext in [".yaml", ".yml"]:
                    try:
                        with open(file_path, "r") as stream:
                            return yaml.safe_load(stream)
                    except (OSError, yaml.YAMLError) as e:
                        logger.warning(f"Failed to load YAML file {file_path}: {e}")
                        continue
                elif ext == ".json":
                    try:
                        with open(file_path, "r") as stream:
                            return json.load(stream)
                    except (OSError, json.JSONDecodeError, UnicodeDecodeError) as e:
                        logger.warning(f"Failed to load JSON file {file_path}: {e}")
                        continue
    return None


def get_podman_machine_cdi_config() -> CDI_RETURN_TYPE | None:
    cdi_config = run_cmd(["podman", "machine", "ssh", "cat", "/etc/cdi/nvidia.yaml"], encoding="utf-8").stdout.strip()
    if cdi_config:
        return yaml.safe_load(cdi_config)
    return None


def find_in_cdi(devices: list[str]) -> tuple[list[str], list[str]]:
    """Attempts to find a CDI configuration for each device in devices and
    returns a list of configured devices and a list of unconfigured devices."""
    if platform.system() == "Windows":
        cdi = get_podman_machine_cdi_config()
    else:
        cdi = load_cdi_config(['/var/run/cdi', '/etc/cdi'])
    try:
        cdi_devices = cdi.get("devices", []) if cdi else []
        cdi_device_names = [name for cdi_device in cdi_devices if (name := cdi_device.get("name"))]
    except (AttributeError, KeyError, TypeError) as e:
        # Malformed YAML or JSON. Treat everything as unconfigured but warn.
        logger.warning(f"Unable to process CDI configuration: {e}")
        return ([], devices)

    configured = []
    unconfigured = []
    for device in devices:
        if device in cdi_device_names:
            configured.append(device)
        # A device can be specified by a prefix of the uuid
        elif device.startswith("GPU") and any(name.startswith(device) for name in cdi_device_names):
            configured.append(device)
        else:
            perror(f"Device {device} does not have a CDI configuration")
            unconfigured.append(device)

    return configured, unconfigured
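
# Illustrative: if the CDI spec declares devices named "all" and "0", then
#   find_in_cdi(["0", "1", "all"]) -> (["0", "all"], ["1"])
# with a diagnostic printed for the unconfigured device "1".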


def check_asahi() -> Literal["asahi"] | None:
    if os.path.exists('/proc/device-tree/compatible'):
        try:
            with open('/proc/device-tree/compatible', 'rb') as f:
                content = f.read().split(b"\0")
                if b"apple,arm-platform" in content:
                    os.environ["ASAHI_VISIBLE_DEVICES"] = "1"
                    return "asahi"
        except OSError:
            pass

    return None


def check_metal(args: ContainerArgType) -> bool:
    if args.container:
        return False
    return platform.system() == "Darwin"


@lru_cache(maxsize=1)
def check_nvidia() -> Literal["cuda"] | None:
    try:
        command = ['nvidia-smi', '--query-gpu=index,uuid', '--format=csv,noheader']
        result = run_cmd(command, encoding="utf-8")
    except (OSError, subprocess.CalledProcessError):
        return None

    smi_lines = result.stdout.splitlines()
    parsed_lines: list[list[str]] = [[item.strip() for item in line.split(',')] for line in smi_lines if line]

    if not parsed_lines:
        return None

    indices, uuids = map(list, zip(*parsed_lines))
    # Get the list of devices specified by CUDA_VISIBLE_DEVICES, if any
    cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "")
    visible_devices = cuda_visible_devices.split(',') if cuda_visible_devices else []
    for device in visible_devices:
        if device not in indices and not any(uuid.startswith(device) for uuid in uuids):
            perror(f"{device} not found")
            return None

    configured, unconfigured = find_in_cdi(visible_devices + ["all"])

    configured_has_all = "all" in configured
    if unconfigured and not configured_has_all:
        perror(f"No CDI configuration found for {','.join(unconfigured)}")
        perror("You can use the \"nvidia-ctk cdi generate\" command from the ")
        perror("nvidia-container-toolkit to generate a CDI configuration.")
        perror("See ramalama-cuda(7).")
        return None
    elif configured:
        if configured_has_all:
            configured.remove("all")
        if not configured:
            configured = indices

        os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(configured)
        return "cuda"

    return None
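
# Illustrative: with two GPUs and CUDA_VISIBLE_DEVICES="1", check_nvidia()
# confirms "1" is a known index (or a uuid prefix), checks for a matching CDI
# entry, re-exports CUDA_VISIBLE_DEVICES, and returns "cuda".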


def check_ascend() -> Literal["cann"] | None:
    try:
        command = ['npu-smi', 'info']
        run_cmd(command, encoding="utf-8")
        os.environ["ASCEND_VISIBLE_DEVICES"] = "0"
        return "cann"
    except Exception:
        pass

    return None


def check_rocm_amd() -> Literal["hip"] | None:
    if is_arm():
        # ROCm is not available for arm64, use Vulkan instead
        return None

    gpu_num = 0
    gpu_bytes = 0
    for i, (np, props) in enumerate(amdkfd.gpus()):
        # Radeon GPUs older than gfx900 are not supported by ROCm (e.g. Polaris)
        if props['gfx_target_version'] < 90000:
            continue

        mem_banks_count = int(props['mem_banks_count'])
        mem_bytes = 0
        for bank in range(mem_banks_count):
            bank_props = amdkfd.parse_props(np + f'/mem_banks/{bank}/properties')
            # See /usr/include/linux/kfd_sysfs.h for possible heap types
            #
            # Count public and private framebuffer memory as VRAM
            if bank_props['heap_type'] in [amdkfd.HEAP_TYPE_FB_PUBLIC, amdkfd.HEAP_TYPE_FB_PRIVATE]:
                mem_bytes += int(bank_props['size_in_bytes'])

        if mem_bytes > MIN_VRAM_BYTES and mem_bytes > gpu_bytes:
            gpu_bytes = mem_bytes
            gpu_num = i

    if gpu_bytes:
        os.environ["HIP_VISIBLE_DEVICES"] = str(gpu_num)
        return "hip"

    return None


def is_arm() -> bool:
    return platform.machine() in ('arm64', 'aarch64')


def check_intel() -> Literal["intel"] | None:
    igpu_num = 0
    # Device IDs for select Intel GPUs. See: https://dgpu-docs.intel.com/devices/hardware-table.html
    intel_gpus = (
        b"0xe20b",
        b"0xe20c",
        b"0x46a6",
        b"0x46a8",
        b"0x46aa",
        b"0x56a0",
        b"0x56a1",
        b"0x7d51",
        b"0x7dd5",
        b"0x7d55",
    )
    intel_driver_glob_patterns = ["/sys/bus/pci/drivers/i915/*/device", "/sys/bus/pci/drivers/xe/*/device"]
    # Check to see if any of the device ids in intel_gpus are in the device id of the i915 / xe driver
    for fp in sorted([i for p in intel_driver_glob_patterns for i in glob.glob(p)]):
        with open(fp, 'rb') as file:
            content = file.read()
            for gpu_id in intel_gpus:
                if gpu_id in content:
                    igpu_num += 1
    if igpu_num:
        os.environ["INTEL_VISIBLE_DEVICES"] = str(igpu_num)
        return "intel"

    return None


def check_mthreads() -> Literal["musa"] | None:
    try:
        command = ['mthreads-gmi']
        run_cmd(command, encoding="utf-8")
        os.environ["MUSA_VISIBLE_DEVICES"] = "0"
        return "musa"
    except Exception:
        pass

    return None


AccelType: TypeAlias = Literal["asahi", "cuda", "cann", "hip", "intel", "musa"]


def get_accel() -> AccelType | Literal["none"]:
    checks: tuple[Callable[[], AccelType | None], ...] = (
        check_asahi,
        cast(Callable[[], Literal['cuda'] | None], check_nvidia),
        check_ascend,
        check_rocm_amd,
        check_intel,
        check_mthreads,
    )
    for check in checks:
        if result := check():
            return result
    return "none"


def set_accel_env_vars():
    if get_accel_env_vars():
        return

    get_accel()


def set_gpu_type_env_vars():
    if get_gpu_type_env_vars():
        return

    get_accel()


GPUEnvVar: TypeAlias = Literal[
    "ASAHI_VISIBLE_DEVICES",
    "ASCEND_VISIBLE_DEVICES",
    "CUDA_VISIBLE_DEVICES",
    "GGML_VK_VISIBLE_DEVICES",
    "HIP_VISIBLE_DEVICES",
    "INTEL_VISIBLE_DEVICES",
    "MUSA_VISIBLE_DEVICES",
]


def get_gpu_type_env_vars() -> dict[GPUEnvVar, str]:
    return {k: v for k in get_args(GPUEnvVar) if (v := os.environ.get(k))}


AccelEnvVar: TypeAlias = Literal[
    "CUDA_LAUNCH_BLOCKING",
    "HSA_VISIBLE_DEVICES",
    "HSA_OVERRIDE_GFX_VERSION",
]


def get_accel_env_vars() -> dict[GPUEnvVar | AccelEnvVar, str]:
    gpu_env_vars: dict[GPUEnvVar, str] = get_gpu_type_env_vars()
    accel_env_vars: dict[AccelEnvVar, str] = {k: v for k in get_args(AccelEnvVar) if (v := os.environ.get(k))}
    return gpu_env_vars | accel_env_vars


def rm_until_substring(input: str, substring: str) -> str:
    pos = input.find(substring)
    if pos == -1:
        return input
    return input[pos + len(substring) :]
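
# Illustrative:
#   rm_until_substring("hf://org/repo", "://") -> "org/repo"
#   rm_until_substring("plain-name", "://")    -> "plain-name" (substring absent)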


def minor_release() -> str:
    version_split = version().split(".")
    vers = ".".join(version_split[:2])
    if vers == "0":
        vers = "latest"
    return vers


def tagged_image(image: str) -> str:
    if len(image.split(":")) > 1:
        return image
    return f"{image}:{minor_release()}"


def check_cuda_version() -> tuple[int, int]:
    """
    Check the CUDA version installed on the system by parsing the output of nvidia-smi.

    Returns:
        tuple: A tuple of (major, minor) version numbers, or (0, 0) if CUDA is not found
        or the version can't be determined.
    """
    try:
        # Run nvidia-smi to get version info
        command = ['nvidia-smi']
        output = run_cmd(command, encoding="utf-8").stdout.strip()

        # Look for CUDA Version in the output
        cuda_match = re.search(r'CUDA Version\s*:\s*(\d+)\.(\d+)', output)
        if cuda_match:
            major = int(cuda_match.group(1))
            minor = int(cuda_match.group(2))
            return (major, minor)
    except Exception:
        pass

    return (0, 0)
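
# Illustrative: a banner line such as "CUDA Version: 12.4" (or the padded
# "CUDA Version     : 12.4") parses to (12, 4); any failure yields (0, 0).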


def select_cuda_image(config: Config) -> str:
    """
    Select appropriate CUDA image based on the detected CUDA version.

    Args:
        config: The configuration object containing the CUDA image reference

    Returns:
        str: The appropriate CUDA image name

    Raises:
        NotImplementedError: If CUDA version is less than 12.4
    """
    # Get the default CUDA image from config
    cuda_image = config.images.get("CUDA_VISIBLE_DEVICES")

    if cuda_image is None:
        raise NotImplementedError("No image repository found for CUDA_VISIBLE_DEVICES in config.")

    # Check CUDA version and select appropriate image
    cuda_version = check_cuda_version()

    # Select appropriate image based on CUDA version
    if cuda_version >= (12, 8):
        return cuda_image  # Use the standard image for CUDA 12.8+
    elif cuda_version >= (12, 4):
        return f"{cuda_image}-12.4.1"  # Use the specific version for older CUDA
    else:
        raise NotImplementedError(f"CUDA version {cuda_version} is not supported. Minimum required version is 12.4.")


class AccelImageArgsWithImage(Protocol):
    image: str


class AccelImageArgsVLLMRuntime(Protocol):
    runtime: Literal["vllm"]


class AccelImageArgsOtherRuntime(Protocol):
    runtime: str
    container: bool
    quiet: bool


class AccelImageArgsOtherRuntimeRAG(Protocol):
    rag: bool
    runtime: str
    container: bool
    quiet: bool


AccelImageArgs: TypeAlias = (
    None | AccelImageArgsVLLMRuntime | AccelImageArgsOtherRuntime | AccelImageArgsOtherRuntimeRAG
)


def accel_image(config: Config, images: RamalamaImageConfig | None = None, conf_key: str = "image") -> str:
    """
    Selects the appropriate image based on config, arguments, and environment.

    "images" is a mapping of environment variable names to image names. If not specified, the
    mapping from the default config will be used.
    "conf_key" is the configuration key that holds the configured value of the selected image.
    If not specified, it defaults to "image".
    """
    # User provided an image via config
    if config.is_set(conf_key):
        return tagged_image(getattr(config, conf_key))

    if not images:
        images = config.images

    set_gpu_type_env_vars()
    gpu_type = next(iter(get_gpu_type_env_vars()), "")

    if config.runtime == "vllm":
        # Check for GPU-specific VLLM image, with a fallback to the generic one.
        image = None
        if gpu_type:
            image = config.images.get(f"VLLM_{gpu_type}")

        if not image:
            image = config.images.get("VLLM", "docker.io/vllm/vllm-openai")

        # If the image from the config is specified by tag or digest, return it unmodified
        return image if ":" in image else f"{image}:latest"

    # Get image based on detected GPU type
    image = images.get(gpu_type, getattr(config, f"default_{conf_key}"))

    # If the image from the config is specified by tag or digest, return it unmodified
    if ":" in image:
        return image

    # Special handling for CUDA images based on version - only if the image is the default CUDA image
    if conf_key == "image" and image == images.get("CUDA_VISIBLE_DEVICES"):
        try:
            image = select_cuda_image(config)
        except NotImplementedError as e:
            logger.warning(f"{e}: Falling back to default image.")
            image = config.default_image

    vers = minor_release()

    should_pull = config.pull in ["always", "missing"] and not config.dryrun
    if config.engine and attempt_to_use_versioned(config.engine, image, vers, True, should_pull):
        return f"{image}:{vers}"

    return f"{image}:latest"


def attempt_to_use_versioned(conman: str, image: str, vers: str, quiet: bool, should_pull: bool) -> bool:
    try:
        # check if versioned image exists locally
        if run_cmd([conman, "inspect", f"{image}:{vers}"], ignore_all=True):
            return True

    except Exception:
        pass

    if not should_pull:
        return False

    try:
        # attempt to pull the versioned image
        if not quiet:
            perror(f"Attempting to pull {image}:{vers} ...")
        run_cmd([conman, "pull", f"{image}:{vers}"], ignore_stderr=True)
        return True

    except Exception:
        return False


class ContainerEntryPoint(str):
    def __init__(self, entrypoint: Optional[str] = None):
        self.entrypoint = entrypoint

    def __str__(self):
        return str(self.entrypoint)

    def __repr__(self):
        return repr(self.entrypoint)