Merge pull request #2239 from olliewalsh/windows_part_deux

Windows support part 2

.flake8 (2 lines changed)
@@ -2,4 +2,4 @@
 max-line-length = 120
 # E203,E221,E231 conflict with black formatting
 extend-ignore = E203,E221,E231,E702,F824
-extend-exclude = .venv,venv
+extend-exclude = .venv,venv,build

.github/workflows/ci.yml (27 lines changed)
@@ -524,17 +524,9 @@ jobs:
        run: |
          Write-Host "Installing Podman on Windows..."
-
-          # Download and install Podman
-          $podmanVersion = "5.7.0"
-          $installerUrl = "https://github.com/containers/podman/releases/download/v$podmanVersion/podman-$podmanVersion-setup.exe"
-          $installerPath = "$env:TEMP\podman-setup.exe"
-
-          Write-Host "Downloading Podman v$podmanVersion..."
-          Invoke-WebRequest -Uri $installerUrl -OutFile $installerPath
-
-          Write-Host "Installing Podman..."
-          Start-Process -FilePath $installerPath -ArgumentList "/install", "/quiet", "/norestart" -Wait -NoNewWindow
+          winget install --accept-source-agreements --silent --disable-interactivity --exact --id RedHat.Podman
+
+          # TODO: remove when this just works https://github.com/microsoft/winget-cli/issues/549
           # Add Podman to PATH for current session
           $podmanPath = "$env:ProgramFiles" + "\RedHat\Podman"
           $env:PATH = "$podmanPath;$env:PATH"
@@ -573,6 +565,21 @@ jobs:
          Write-Host "Podman info:"
          podman info

+      - name: Install OpenSSL
+        shell: pwsh
+        run: |
+          winget install --accept-source-agreements --silent --disable-interactivity --exact --id ShiningLight.OpenSSL.Light
+
+          # Add OpenSSL to PATH for current session
+          $opensslPath = "$env:ProgramFiles" + "\OpenSSL-Win64\bin"
+          $env:PATH = "$opensslPath;$env:PATH"
+          [Environment]::SetEnvironmentVariable("PATH", $env:PATH, [EnvironmentVariableTarget]::Process)
+
+          # Update PATH for future steps
+          Add-Content -Path $env:GITHUB_PATH -Value "$opensslPath"
+
+          openssl version
+
       - name: Run E2E tests
         shell: pwsh
         env:

@@ -8,7 +8,7 @@ ENTRYPOINT ["/usr/bin/entrypoint.sh"]

 RUN dnf -y --setopt=install_weak_deps=false install \
     make bats jq iproute podman openssl httpd-tools diffutils procps-ng \
-    gcc \
+    gcc cargo \
     $([ $(uname -m) == "x86_64" ] && echo ollama) \
     # for validate and unit-tests
     shellcheck \

@@ -58,6 +58,7 @@ text = "MIT"
 dev = [
     # "pytest>=7.0",
     "argcomplete~=3.0",
+    "bcrypt",
     "black~=25.0",
     "codespell~=2.0",
     "flake8~=7.0",
@@ -69,7 +70,7 @@ dev = [
     "mypy",
     "types-PyYAML",
     "types-jsonschema",
-    "tox"
+    "tox",
 ]

 cov = [

@@ -97,7 +97,6 @@ def add_api_key(args, headers=None):
 @dataclass
 class ChatOperationalArgs:
     initial_connection: bool = False
-    pid2kill: int | None = None
     name: str | None = None
     keepalive: int | None = None
+    monitor: "ServerMonitor | None" = None
@@ -485,23 +484,12 @@ class RamaLamaShell(cmd.Cmd):
         if getattr(self.args, "initial_connection", False):
             return

-        if getattr(self.args, "pid2kill", False):
-            # Send signals to terminate process
-            # On Windows, only SIGTERM and SIGINT are supported
-            try:
-                os.kill(self.args.pid2kill, signal.SIGINT)
-            except (ProcessLookupError, AttributeError):
-                pass
-            try:
-                os.kill(self.args.pid2kill, signal.SIGTERM)
-            except (ProcessLookupError, AttributeError):
-                pass
-            # SIGKILL doesn't exist on Windows, use SIGTERM instead
-            if hasattr(signal, 'SIGKILL'):
-                try:
-                    os.kill(self.args.pid2kill, signal.SIGKILL)
-                except ProcessLookupError:
-                    pass
+        if getattr(self.args, "server_process", False):
+            self.args.server_process.terminate()
+            try:
+                self.args.server_process.wait(timeout=5)
+            except subprocess.TimeoutExpired:
+                self.args.server_process.kill()
         elif getattr(self.args, "name", None):
             args = copy.copy(self.args)
             args.ignore = True
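
The removed branch terminated the server with raw signals via `os.kill`, which needs per-platform special-casing (`SIGKILL` does not exist on Windows, hence the fallback dance above). The replacement leans on `subprocess.Popen`, whose `terminate()` and `kill()` are portable wrappers (SIGTERM/SIGKILL on POSIX, TerminateProcess on Windows). A minimal standalone sketch of that shutdown pattern, with illustrative names:

    import subprocess

    def stop_server(proc: subprocess.Popen, grace_seconds: float = 5.0) -> int:
        """Ask the server to exit, then force-kill it after a grace period."""
        proc.terminate()  # SIGTERM on POSIX, TerminateProcess() on Windows
        try:
            proc.wait(timeout=grace_seconds)
        except subprocess.TimeoutExpired:
            proc.kill()  # SIGKILL on POSIX, same as terminate() on Windows
            proc.wait()
        return proc.returncode
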
@@ -548,41 +536,40 @@ class ServerMonitor:

     def __init__(
         self,
         server_pid=None,
+        server_process=None,
         container_name=None,
         container_engine=None,
         join_timeout=3.0,
         check_interval=0.5,
+        inspect_timeout=30.0,
     ):
         """
         Initialize the server monitor.

         Args:
             server_pid: Process ID to monitor (for direct process monitoring)
+            server_process: subprocess.Popen object to monitor
             container_name: Container name to monitor (for container monitoring)
             container_engine: Container engine command (podman/docker)
             join_timeout: Seconds for thread join when stopping (default: 3.0)
             check_interval: Seconds between monitoring checks (default: 0.5)
+            inspect_timeout: Seconds to wait for container inspect command to complete (default: 30.0)

-        Note: If neither server_pid nor container_name is provided, the monitor
+        Note: If neither server_process nor container_name is provided, the monitor
         operates in no-op mode (no actual monitoring occurs).
         """
         self.server_pid = server_pid
+        self.server_process = server_process
         self.container_name = container_name
         self.container_engine = container_engine
         self.timeout = join_timeout
         self.check_interval = check_interval
+        self.inspect_timeout = inspect_timeout
         self._stop_event = threading.Event()
         self._exited_event = threading.Event()
         self._exit_info = {}
         self._monitor_thread = None

         # Determine monitoring mode
-        if server_pid:
-            if sys.platform == "win32":
-                # os.waitpid() is not available for non-child processes on Windows.
-                raise NotImplementedError("Process monitoring by PID is not supported on Windows.")
+        if self.server_process:
             self._mode = "process"
         elif container_name and container_engine:
             self._mode = "container"
@@ -630,31 +617,24 @@ class ServerMonitor:
         """Monitor the server process and report if it exits."""
         while not self._stop_event.is_set():
             try:
-                # Use waitpid with WNOHANG to check without blocking
-                pid, status = os.waitpid(self.server_pid, os.WNOHANG)
-                if pid != 0:
+                exit_code = self.server_process.poll()
+                if exit_code is not None:
                     # Process has exited
-                    self._exit_info["pid"] = self.server_pid
-                    self._exit_info["status"] = status
-                    if os.WIFEXITED(status):
-                        self._exit_info["type"] = "exit"
-                        self._exit_info["code"] = os.WEXITSTATUS(status)
-                    elif os.WIFSIGNALED(status):
-                        self._exit_info["type"] = "signal"
-                        self._exit_info["signal"] = os.WTERMSIG(status)
-                    else:
-                        self._exit_info["type"] = "unknown"
+                    self._exit_info["pid"] = self.server_process.pid
+                    self._exit_info["type"] = "exit"
+                    self._exit_info["code"] = exit_code
                     self._exited_event.set()
                     # Send SIGINT to main process to interrupt the chat
                     _thread.interrupt_main()
                     break
-            except ChildProcessError:
-                # Process doesn't exist or already reaped
-                self._exit_info["pid"] = self.server_pid
+            except Exception as e:
+                logger.debug(f"Error monitoring process: {e}", exc_info=True)
+                self._exit_info["pid"] = self.server_process.pid
                 self._exit_info["type"] = "missing"
                 self._exited_event.set()
                 _thread.interrupt_main()
                 break

             # Use wait() instead of sleep() for responsive shutdown
             self._stop_event.wait(self.check_interval)
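
The rewritten loop replaces `os.waitpid(..., os.WNOHANG)`, which is POSIX-only and requires the target to be a child process, with `Popen.poll()`, which works everywhere. A stripped-down sketch of the same watcher shape (the names here are illustrative, not RamaLama's):

    import _thread
    import subprocess
    import threading

    def watch(proc: subprocess.Popen, stop: threading.Event, interval: float = 0.5) -> None:
        """Background thread body: interrupt the main thread if the server dies."""
        while not stop.is_set():
            if proc.poll() is not None:  # non-blocking; None while still running
                _thread.interrupt_main()  # raises KeyboardInterrupt in the main thread
                break
            stop.wait(interval)  # responsive to stop.set(), unlike time.sleep()

    # Typical wiring:
    #   stop = threading.Event()
    #   threading.Thread(target=watch, args=(proc, stop), daemon=True).start()
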
@@ -668,7 +648,7 @@ class ServerMonitor:
                 [self.container_engine, "inspect", "--format", inspect_format, self.container_name],
                 capture_output=True,
                 text=True,
-                timeout=self.check_interval,
+                timeout=self.inspect_timeout,
             )
             output_lines = result.stdout.strip().split('\n')
             status = output_lines[0] if output_lines else ""
@@ -753,13 +733,12 @@ def chat(args: ChatArgsType, operational_args: ChatOperationalArgs | None = None
         signal.alarm(convert_to_seconds(args.keepalive))  # type: ignore

     # Start server process or container monitoring
-    # Check if we should monitor a process (pid2kill) or container (name)
-    pid2kill = getattr(args, "pid2kill", None)
+    server_process = getattr(args, "server_process", None)
     container_name = getattr(args, "name", None)

-    if pid2kill:
+    if server_process:
         # Monitor the server process
-        monitor = ServerMonitor(server_pid=pid2kill)
+        monitor = ServerMonitor(server_process=server_process)
     elif container_name:
         # Monitor the container
         conman = getattr(args, "engine", CONFIG.engine)

@@ -122,18 +122,14 @@ def quoted(arr) -> str:

 def exec_cmd(args, stdout2null: bool = False, stderr2null: bool = False):
     logger.debug(f"exec_cmd: {quoted(args)}")
-    if stdout2null:
-        with open(os.devnull, 'w') as devnull:
-            os.dup2(devnull.fileno(), sys.stdout.fileno())
-
-    if stderr2null:
-        with open(os.devnull, 'w') as devnull:
-            os.dup2(devnull.fileno(), sys.stderr.fileno())
-
+    stdout_target = subprocess.DEVNULL if stdout2null else None
+    stderr_target = subprocess.DEVNULL if stderr2null else None
     try:
-        return os.execvp(args[0], args)
-    except Exception:
-        perror(f"os.execvp({args[0]}, {args})")
+        result = subprocess.run(args, stdout=stdout_target, stderr=stderr_target, check=False)
+        sys.exit(result.returncode)
+    except Exception as e:
+        perror(f"Failed to execute {quoted(args)}: {e}")
         raise
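
`os.execvp()` never returns on POSIX, but on Windows the exec family does not cleanly replace the process (the new program runs while the caller appears to exit, which confuses consoles and exit-code propagation). Running the child with `subprocess.run` and exiting with its return code is the portable equivalent, which is what the new body does. A self-contained sketch of the pattern:

    import subprocess
    import sys

    def exec_portably(args: list[str], quiet: bool = False) -> None:
        """Run `args` in place of this process, forwarding its exit code."""
        target = subprocess.DEVNULL if quiet else None
        result = subprocess.run(args, stdout=target, stderr=target, check=False)
        sys.exit(result.returncode)
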
@@ -226,23 +222,27 @@ def populate_volume_from_image(model: Transport, args: Namespace, output_filenam
     return volume


-def generate_sha256(to_hash: str, with_sha_prefix: bool = True) -> str:
+def generate_sha256_binary(to_hash: bytes, with_sha_prefix: bool = True) -> str:
     """
-    Generates a sha256 for a string.
+    Generates a sha256 for data bytes.

     Args:
-        to_hash (str): The string to generate the sha256 hash for.
+        to_hash (bytes): The data to generate the sha256 hash for.

     Returns:
         str: Hex digest of the input appended to the prefix sha256-
     """
     h = hashlib.new("sha256")
-    h.update(to_hash.encode("utf-8"))
+    h.update(to_hash)
     if with_sha_prefix:
         return f"sha256-{h.hexdigest()}"
     return h.hexdigest()


+def generate_sha256(to_hash: str, with_sha_prefix: bool = True) -> str:
+    return generate_sha256_binary(to_hash.encode("utf-8"), with_sha_prefix)
+
+
 def verify_checksum(filename: str) -> bool:
     """
     Verifies if the SHA-256 checksum of a file matches the checksum provided in
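
Splitting the helper keeps the existing text API while letting callers hash raw bytes directly, and the two stay consistent because the string version just encodes and delegates. A quick check of what the pair computes, using only `hashlib` so it runs standalone:

    import hashlib

    text = "hello"
    # generate_sha256(text) is defined as generate_sha256_binary(text.encode("utf-8")),
    # so the text and bytes forms reduce to the same digest:
    assert hashlib.sha256(text.encode("utf-8")).hexdigest() == hashlib.sha256(b"hello").hexdigest()
    print(f"sha256-{hashlib.sha256(b'hello').hexdigest()}")  # the with_sha_prefix=True form
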
@@ -326,11 +326,21 @@ def load_cdi_config(spec_dirs: list[str]) -> CDI_RETURN_TYPE | None:
     return None


+def get_podman_machine_cdi_config() -> CDI_RETURN_TYPE | None:
+    cdi_config = run_cmd(["podman", "machine", "ssh", "cat", "/etc/cdi/nvidia.yaml"], encoding="utf-8").stdout.strip()
+    if cdi_config:
+        return yaml.safe_load(cdi_config)
+    return None
+
+
 def find_in_cdi(devices: list[str]) -> tuple[list[str], list[str]]:
     # Attempts to find a CDI configuration for each device in devices
     # and returns a list of configured devices and a list of
     # unconfigured devices.
-    cdi = load_cdi_config(['/var/run/cdi', '/etc/cdi'])
+    if platform.system() == "Windows":
+        cdi = get_podman_machine_cdi_config()
+    else:
+        cdi = load_cdi_config(['/var/run/cdi', '/etc/cdi'])
     try:
         cdi_devices = cdi.get("devices", []) if cdi else []
         cdi_device_names = [name for cdi_device in cdi_devices if (name := cdi_device.get("name"))]
@@ -712,7 +722,7 @@ def accel_image(config: Config, images: RamalamaImageConfig | None = None, conf_
     try:
         image = select_cuda_image(config)
     except NotImplementedError as e:
-        logger.warn(f"{e}: Falling back to default image.")
+        logger.warning(f"{e}: Falling back to default image.")
         image = config.default_image

     vers = minor_release()

@@ -1,12 +1,10 @@
 # The following code is inspired from: https://github.com/ericcurtin/lm-pull/blob/main/lm-pull.py

 import os
-import sys
+import platform

 # Import platform-specific locking mechanisms
-if sys.platform == 'win32':
-    import msvcrt
-else:
+if platform.system() != "Windows":
     import fcntl
@@ -22,30 +20,23 @@ class File:
     def lock(self):
         if self.file:
             self.fd = self.file.fileno()
-            try:
-                if sys.platform == 'win32':
-                    # Windows file locking using msvcrt
-                    msvcrt.locking(self.fd, msvcrt.LK_NBLCK, 1)
-                else:
+            if platform.system() != "Windows":
+                try:
                     # Unix file locking using fcntl
                     fcntl.flock(self.fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
-            except (BlockingIOError, OSError):
-                self.fd = -1
-                return 1
+                except (BlockingIOError, OSError):
+                    self.fd = -1
+                    return 1
         return 0

     def __del__(self):
         if self.fd >= 0:
-            try:
-                if sys.platform == 'win32':
-                    # Unlock on Windows
-                    msvcrt.locking(self.fd, msvcrt.LK_UNLCK, 1)
-                else:
+            if platform.system() != "Windows":
+                try:
                     # Unlock on Unix
                     fcntl.flock(self.fd, fcntl.LOCK_UN)
-            except OSError:
-                pass  # File may already be closed
+                except OSError:
+                    pass  # File may already be closed

         if self.file:
             self.file.close()
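
After this change Windows skips locking entirely (the earlier `msvcrt.locking` path is gone), while Unix keeps the non-blocking `fcntl.flock`. A standalone sketch of that Unix path, with a hypothetical lock-file name:

    import fcntl  # POSIX-only; this module does not exist on Windows

    with open("/tmp/ramalama-demo.lock", "w") as lock_file:
        try:
            # LOCK_NB makes acquisition fail immediately instead of blocking
            fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
        except (BlockingIOError, OSError):
            raise SystemExit("another process holds the lock")
        # ... critical section: download and rename files here ...
        fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN)
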
@@ -24,7 +24,7 @@ class HttpClient:
     def __init__(self):
         pass

-    def init(self, url, headers, output_file, show_progress, response_str=None):
+    def init(self, url, headers, output_file, show_progress, response_bytes=None):
         output_file_partial = None
         if output_file:
             output_file_partial = output_file + ".partial"
@@ -32,8 +32,8 @@ class HttpClient:
             self.file_size = self.set_resume_point(output_file_partial)
         self.urlopen(url, headers)
         self.total_to_download = int(self.response.getheader('content-length', 0))
-        if response_str is not None:
-            response_str.append(self.response.read().decode('utf-8'))
+        if response_bytes is not None:
+            response_bytes.append(self.response.read())
         else:
             out = File()
             try:

@@ -1,8 +1,10 @@
 import os
+import platform
 from typing import Optional, Tuple

 from ramalama.common import MNT_DIR, RAG_DIR, genname, get_accel_env_vars
 from ramalama.file import PlainFile
+from ramalama.path_utils import normalize_host_path_for_container
 from ramalama.version import version

@@ -86,12 +88,16 @@ class Kube:
         return mounts, volumes

     def _gen_path_volume(self):
+        host_model_path = normalize_host_path_for_container(self.src_model_path)
+        if platform.system() == "Windows":
+            # Workaround https://github.com/containers/podman/issues/16704
+            host_model_path = '/mnt' + host_model_path
         mount = f"""
 - mountPath: {self.dest_model_path}
   name: model"""
         volume = f"""
 - hostPath:
-    path: {self.src_model_path}
+    path: {host_model_path}
   name: model"""
         return mount, volume

@@ -116,22 +122,30 @@ class Kube:
         return mounts, volumes

     def _gen_chat_template_volume(self):
+        host_chat_template_path = normalize_host_path_for_container(self.src_chat_template_path)
+        if platform.system() == "Windows":
+            # Workaround https://github.com/containers/podman/issues/16704
+            host_chat_template_path = '/mnt' + host_chat_template_path
         mount = f"""
 - mountPath: {self.dest_chat_template_path}
   name: chat_template"""
         volume = f"""
 - hostPath:
-    path: {self.src_chat_template_path}
+    path: {host_chat_template_path}
   name: chat_template"""
         return mount, volume

     def _gen_mmproj_volume(self):
+        host_mmproj_path = normalize_host_path_for_container(self.src_mmproj_path)
+        if platform.system() == "Windows":
+            # Workaround https://github.com/containers/podman/issues/16704
+            host_mmproj_path = '/mnt' + host_mmproj_path
         mount = f"""
 - mountPath: {self.dest_mmproj_path}
   name: mmproj"""
         volume = f"""
 - hostPath:
-    path: {self.src_mmproj_path}
+    path: {host_mmproj_path}
   name: mmproj"""
         return mount, volume
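
The same three-line `/mnt` workaround now appears in `_gen_path_volume`, `_gen_chat_template_volume`, and `_gen_mmproj_volume` (and again in the stack generator further down); it relies on the podman machine VM exposing host drives under `/mnt` (e.g. `/mnt/c/...`). If the repetition grows it could be folded into one helper; a hypothetical sketch of such a consolidation:

    import platform

    from ramalama.path_utils import normalize_host_path_for_container

    def kube_host_path(path: str) -> str:
        """Map a host path to what `podman kube play` can mount, per this diff."""
        host_path = normalize_host_path_for_container(path)
        if platform.system() == "Windows":
            # Workaround https://github.com/containers/podman/issues/16704
            host_path = '/mnt' + host_path
        return host_path
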
@@ -2,7 +2,7 @@ import os
 from enum import IntEnum
 from typing import Dict, Sequence

-from ramalama.common import generate_sha256, perror
+from ramalama.common import generate_sha256_binary, perror
 from ramalama.http_client import download_file
 from ramalama.logger import logger

@@ -57,7 +57,7 @@ class SnapshotFile:
 class LocalSnapshotFile(SnapshotFile):
     def __init__(
         self,
-        content: str,
+        content: bytes,
         name: str,
         type: SnapshotFileType,
         should_show_progress: bool = False,
@@ -67,7 +67,7 @@ class LocalSnapshotFile(SnapshotFile):
         super().__init__(
             "",
             {},
-            generate_sha256(content),
+            generate_sha256_binary(content),
             name,
             type,
             should_show_progress,
@@ -77,7 +77,7 @@ class LocalSnapshotFile(SnapshotFile):
         self.content = content

     def download(self, blob_file_path, snapshot_dir):
-        with open(blob_file_path, "w") as file:
+        with open(blob_file_path, "wb") as file:
             file.write(self.content)
             file.flush()
         return os.path.relpath(blob_file_path, start=snapshot_dir)
@@ -266,7 +266,11 @@ class ModelStore:
             logger.debug(f"Failed to convert template: {e}")
             continue

-        files = [LocalSnapshotFile(normalized_template, "chat_template_converted", SnapshotFileType.ChatTemplate)]
+        files = [
+            LocalSnapshotFile(
+                normalized_template.encode("utf-8"), "chat_template_converted", SnapshotFileType.ChatTemplate
+            )
+        ]
         self._update_snapshot(ref_file, snapshot_hash, files)
         return True

@@ -302,7 +306,7 @@ class ModelStore:
         # chat template if it is a Go Template (ollama-specific)
         files = [
             LocalSnapshotFile(
-                tmpl,
+                tmpl.encode("utf-8"),
                 "chat_template_extracted",
                 SnapshotFileType.Other if needs_conversion else SnapshotFileType.ChatTemplate,
             )
@@ -311,7 +315,9 @@ class ModelStore:
         try:
             desired_template = convert_go_to_jinja(tmpl)
             files.append(
-                LocalSnapshotFile(desired_template, "chat_template_converted", SnapshotFileType.ChatTemplate)
+                LocalSnapshotFile(
+                    desired_template.encode("utf-8"), "chat_template_converted", SnapshotFileType.ChatTemplate
+                )
             )
         except Exception as ex:
             logger.debug(f"Failed to convert Go Template to Jinja: {ex}")

@@ -397,6 +403,7 @@ class ModelStore:
         return True

     def _remove_blob_file(self, snapshot_file_path: str):
+        # FIXME: this assumes a symlink, fails on Windows
         blob_path = Path(snapshot_file_path).resolve()
         try:
             if os.path.exists(blob_path) and Path(self.base_path) in blob_path.parents:

@@ -1,7 +1,7 @@
 """Utilities for cross-platform path handling, especially for Windows Docker/Podman support."""

 import os
-import sys
+import platform
 from pathlib import Path, PureWindowsPath

@@ -27,7 +27,7 @@ def normalize_host_path_for_container(host_path: str) -> str:
         Windows: "C:\\Users\\John\\models" -> "/c/Users/John/models"
         Linux: "/home/john/models" -> "/home/john/models"
     """
-    if sys.platform != 'win32':
+    if platform.system() != "Windows":
         # On Linux/macOS, paths are already in the correct format
         return host_path

@@ -37,8 +37,13 @@ def normalize_host_path_for_container(host_path: str) -> str:
     # First, resolve symlinks and make the path absolute.
     path = Path(host_path).resolve()

-    # Handle UNC paths (e.g., \\server\share). On Windows, these resolve to
-    # absolute paths but don't have a drive letter.
+    # Handle UNC paths to container filesystem
+    # e.g if the model store is placed on the podman machine VM to reduce copying
+    # \\wsl.localhost\podman-machine-default\home\user\.local\share\ramalama\store
+    # NOTE: UNC paths cannot be accessed implicitly from the container, would need to smb mount
+    if path.drive.startswith("\\\\"):
+        return '/' + path.relative_to(path.drive).as_posix()
+
     if not path.drive:
         return path.as_posix()

@@ -58,7 +63,7 @@ def is_windows_absolute_path(path: str) -> bool:
     Returns:
         True if the path looks like a Windows absolute path (e.g., C:\\, D:\\)
     """
-    if sys.platform != 'win32':
+    if platform.system() != "Windows":
         return False

     return PureWindowsPath(path).is_absolute()
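
For the drive-letter case (handled below this hunk, per the function's docstring), the transform lowercases the drive and flips separators. A standalone illustration of that documented mapping, using only `pathlib` (the helper name is hypothetical):

    from pathlib import PureWindowsPath

    def drive_path_to_posix(host_path: str) -> str:
        # Documented mapping: "C:\\Users\\John\\models" -> "/c/Users/John/models"
        p = PureWindowsPath(host_path)
        drive = p.drive.rstrip(":").lower()
        return "/" + "/".join([drive, *p.parts[1:]])

    assert drive_path_to_posix("C:\\Users\\John\\models") == "/c/Users/John/models"
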
@@ -8,7 +8,7 @@ from ramalama.chat import ChatOperationalArgs
 from ramalama.common import accel_image, perror, set_accel_env_vars
 from ramalama.compat import StrEnum
 from ramalama.config import Config
-from ramalama.engine import BuildEngine, Engine, is_healthy, wait_for_healthy
+from ramalama.engine import BuildEngine, Engine, is_healthy, stop_container, wait_for_healthy
 from ramalama.path_utils import get_container_mount_path
 from ramalama.transports.base import Transport
 from ramalama.transports.oci import OCI
@@ -177,25 +177,41 @@ class RagTransport(OCI):
         args.rag = None
         super()._handle_container_chat(args, pid)

-    def _start_model(self, args, cmd: list[str]):
-        pid = self.imodel._fork_and_serve(args, self.model_cmd)
-        if pid:
-            _, status = os.waitpid(pid, 0)
-            if status != 0:
-                raise subprocess.CalledProcessError(
-                    os.waitstatus_to_exitcode(status),
-                    " ".join(cmd),
-                )
-        return pid
-
     def serve(self, args, cmd: list[str]):
-        pid = self._start_model(args.model_args, cmd)
-        if pid:
+        args.model_args.name = self.imodel.get_container_name(args.model_args)
+        process = self.imodel.serve_nonblocking(args.model_args, self.model_cmd)
+        if not args.dryrun:
+            if process and process.wait() != 0:
+                raise subprocess.CalledProcessError(
+                    process.returncode,
+                    " ".join(self.model_cmd),
+                )
+        try:
             super().serve(args, cmd)
+        finally:
+            if getattr(args.model_args, "name", None):
+                args.model_args.ignore = True
+                stop_container(args.model_args, args.model_args.name, remove=True)

     def run(self, args, cmd: list[str]):
         args.model_args.name = self.imodel.get_container_name(args.model_args)
-        super().run(args, cmd)
+        process = self.imodel.serve_nonblocking(args.model_args, self.model_cmd)
+        rag_process = self.serve_nonblocking(args, cmd)
+
+        if args.dryrun:
+            return
+
+        if process and process.wait() != 0:
+            raise subprocess.CalledProcessError(
+                process.returncode,
+                " ".join(self.model_cmd),
+            )
+        if rag_process and rag_process.wait() != 0:
+            raise subprocess.CalledProcessError(
+                rag_process.returncode,
+                " ".join(cmd),
+            )
+        return self._connect_and_chat(args, rag_process)

     def wait_for_healthy(self, args):
         self.imodel.wait_for_healthy(args.model_args)

@@ -1,11 +1,13 @@
 import os
-import tempfile
+import platform

 import ramalama.kube as kube
 import ramalama.quadlet as quadlet
 from ramalama.common import check_nvidia, exec_cmd, genname, get_accel_env_vars, tagged_image
+from ramalama.compat import NamedTemporaryFile
 from ramalama.config import CONFIG
 from ramalama.engine import add_labels
+from ramalama.path_utils import normalize_host_path_for_container
 from ramalama.transports.base import compute_serving_port
 from ramalama.transports.transport_factory import New

@@ -50,7 +52,7 @@ class Stack:
 - mountPath: {self.model._get_entry_model_path(True, True, False)}
   name: model"""

-        if self.args.dri == "on":
+        if self.args.dri == "on" and platform.system() != "Windows":
             volume_mounts += """
 - mountPath: /dev/dri
   name: dri"""

@@ -58,9 +60,13 @@ class Stack:
         return volume_mounts

     def _gen_volumes(self):
+        host_model_path = normalize_host_path_for_container(self.model._get_entry_model_path(False, False, False))
+        if platform.system() == "Windows":
+            # Workaround https://github.com/containers/podman/issues/16704
+            host_model_path = '/mnt' + host_model_path
         volumes = f"""
 - hostPath:
-    path: {self.model._get_entry_model_path(False, False, False)}
+    path: {host_model_path}
   name: model"""
         if self.args.dri == "on":
             volumes += """

@@ -209,34 +215,35 @@ spec:
         k.write(self.args.generate.output_dir)
         return

-        yaml_file = tempfile.NamedTemporaryFile(prefix='RamaLama_', delete=not self.args.debug)
-        with open(yaml_file.name, 'w') as c:
-            c.write(yaml)
-            c.flush()
+        with NamedTemporaryFile(
+            mode='w', prefix='RamaLama_', delete=not self.args.debug, delete_on_close=False
+        ) as yaml_file:
+            yaml_file.write(yaml)
+            yaml_file.close()

-        exec_args = [
-            self.args.engine,
-            "kube",
-            "play",
-            "--replace",
-        ]
-        if not self.args.detach:
-            exec_args.append("--wait")
-
-        exec_args.append(yaml_file.name)
-        exec_cmd(exec_args)
+            exec_args = [
+                self.args.engine,
+                "kube",
+                "play",
+                "--replace",
+            ]
+            if not self.args.detach:
+                exec_args.append("--wait")
+
+            exec_args.append(yaml_file.name)
+            exec_cmd(exec_args)

     def stop(self):
-        yaml_file = tempfile.NamedTemporaryFile(prefix='RamaLama_', delete=not self.args.debug)
-        with open(yaml_file.name, 'w') as c:
-            c.write(self.generate())
-            c.flush()
+        with NamedTemporaryFile(
+            mode='w', prefix='RamaLama_', delete=not self.args.debug, delete_on_close=False
+        ) as yaml_file:
+            yaml_file.write(self.generate())
+            yaml_file.close()

-        exec_args = [
-            self.args.engine,
-            "kube",
-            "down",
-            yaml_file.name,
-        ]
-
-        exec_cmd(exec_args)
+            exec_args = [
+                self.args.engine,
+                "kube",
+                "down",
+                yaml_file.name,
+            ]
+            exec_cmd(exec_args)
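
Reopening `yaml_file.name` with a second `open()` fails on Windows: while a `NamedTemporaryFile` is open, other handles generally cannot open it by name, and the default `delete=True` removes it at `close()`. `delete_on_close=False`, added to `tempfile.NamedTemporaryFile` in Python 3.12 (and presumably what `ramalama.compat.NamedTemporaryFile` provides on older interpreters), lets the file be closed, handed to podman, and still cleaned up when the `with` block exits:

    import subprocess
    import tempfile

    with tempfile.NamedTemporaryFile(
        mode='w', prefix='RamaLama_', delete=True, delete_on_close=False
    ) as yaml_file:
        yaml_file.write("apiVersion: v1\n")  # stand-in for the generated kube YAML
        yaml_file.close()  # closed but NOT deleted, so podman can open it by name
        subprocess.run(["podman", "kube", "play", "--replace", yaml_file.name], check=False)
    # leaving the with-block deletes the file on every platform
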
@@ -390,39 +390,60 @@ class Transport(TransportBase):
         # The Run command will first launch a daemonized service
         # and run chat to communicate with it.

         args.noout = not args.debug
-        pid = self._fork_and_serve(args, server_cmd)
-        if pid:
-            return self._connect_and_chat(args, pid)
+        process = self.serve_nonblocking(args, server_cmd)
+        if process:
+            return self._connect_and_chat(args, process)

-    def _fork_and_serve(self, args, cmd: list[str]):
+    def serve_nonblocking(self, args, cmd: list[str]) -> subprocess.Popen | None:
         if args.container:
             args.name = self.get_container_name(args)
-        pid = os.fork()
-        if pid == 0:
-            # Child process - start the server
-            self._start_server(args, cmd)
-        return pid
-
-    def _start_server(self, args, cmd: list[str]):
-        """Start the server in the child process."""
+        # Use subprocess.Popen for all platforms
+        # Prepare args for the server
         args.host = CONFIG.host
         args.generate = ""
         args.detach = True
-        self.serve(args, cmd)
-
-    def _connect_and_chat(self, args, server_pid):
         set_accel_env_vars()

+        if args.container:
+            # For container mode, set up the container and start it with subprocess
+            self.setup_container(args)
+            self.setup_mounts(args)
+            # Make sure Image precedes cmd_args
+            self.engine.add([args.image] + cmd)
+
+            if args.dryrun:
+                self.engine.dryrun()
+                return None
+
+            # Start the container using subprocess.Popen
+            process = subprocess.Popen(
+                self.engine.exec_args,
+            )
+            return process
+
+        # Non-container mode: run the command directly with subprocess
+        if args.dryrun:
+            dry_run(cmd)
+            return None
+
+        process = subprocess.Popen(
+            cmd,
+        )
+        return process
+
+    def _connect_and_chat(self, args, server_process):
         """Connect to the server and start chat in the parent process."""
         args.url = f"http://127.0.0.1:{args.port}/v1"
         if getattr(args, "runtime", None) == "mlx":
             args.prefix = "🍏 > "
-        args.pid2kill = ""

         if args.container:
-            return self._handle_container_chat(args, server_pid)
+            return self._handle_container_chat(args, server_process)
         else:
-            args.pid2kill = server_pid
+            # Store the Popen object for monitoring
+            args.server_process = server_process

             if getattr(args, "runtime", None) == "mlx":
                 return self._handle_mlx_chat(args)
             chat.chat(args)
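
`os.fork()` does not exist on Windows, which is exactly what the removed `xfail_if_windows` FIXME markers in the test files were tracking. Spawning with `subprocess.Popen` is the portable replacement, and it also hands back an object that later code can `poll()`, `wait()` on, or `terminate()`. The essential before/after shape, as a runnable sketch with a stand-in server command:

    import subprocess
    import sys

    server_cmd = [sys.executable, "-c", "import time; time.sleep(1)"]  # stand-in server
    process = subprocess.Popen(server_cmd)   # replaces os.fork() + child-side serve()
    assert process.poll() is None            # still running (what the monitor checks)
    exit_code = process.wait()               # replaces os.waitpid(pid, 0)
    assert exit_code == 0
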
@@ -434,10 +455,11 @@ class Transport(TransportBase):
     def wait_for_healthy(self, args):
         wait_for_healthy(args, is_healthy)

-    def _handle_container_chat(self, args, server_pid):
+    def _handle_container_chat(self, args, server_process):
         """Handle chat for container-based execution."""
-        _, status = os.waitpid(server_pid, 0)
-        if status != 0:
+        # Wait for the server process to complete (blocking)
+        exit_code = server_process.wait()
+        if exit_code != 0:
             raise ValueError(f"Failed to serve model {self.model_name}, for ramalama run command")

         if not args.dryrun:
@@ -480,7 +502,7 @@ class Transport(TransportBase):
             except Exception as e:
                 if i >= max_retries - 1:
                     perror(f"Error: Failed to connect to MLX server after {max_retries} attempts: {e}")
-                    self._cleanup_server_process(args.pid2kill)
+                    self._cleanup_server_process(args.server_process)
                     raise e
                 logger.debug(f"Connection attempt failed, retrying... (attempt {i + 1}/{max_retries}): {e}")
                 time.sleep(3)
@@ -498,32 +520,16 @@ class Transport(TransportBase):
         except (socket.error, ValueError):
             return False

-    def _cleanup_server_process(self, pid):
+    def _cleanup_server_process(self, process):
         """Clean up the server process."""
-        if not pid:
+        if not process:
             return

-        import signal
-
+        process.terminate()
         try:
-            # Try graceful termination first
-            os.kill(pid, signal.SIGTERM)
-            time.sleep(1)  # Give it time to terminate gracefully
-
-            # Force kill if still running (SIGKILL is Unix-only)
-            if hasattr(signal, 'SIGKILL'):
-                try:
-                    os.kill(pid, signal.SIGKILL)
-                except ProcessLookupError:
-                    pass
-            else:
-                # On Windows, send SIGTERM again as fallback
-                try:
-                    os.kill(pid, signal.SIGTERM)
-                except ProcessLookupError:
-                    pass
-        except ProcessLookupError:
-            pass
+            process.wait(timeout=5)
+        except subprocess.TimeoutExpired:
+            process.kill()

     def perplexity(self, args, cmd: list[str]):
         set_accel_env_vars()
@@ -636,7 +642,11 @@ class Transport(TransportBase):
             self.generate_container_config(args, cmd)
             return

-        self.execute_command(cmd, args)
+        try:
+            self.execute_command(cmd, args)
+        except Exception as e:
+            self._cleanup_server_process(args.server_process)
+            raise e

     def quadlet(self, model_paths, chat_template_paths, mmproj_paths, args, exec_args, output_dir):
         quadlet = Quadlet(self.model_name, model_paths, chat_template_paths, mmproj_paths, args, exec_args)

@@ -9,6 +9,7 @@ from pathlib import Path
 from tempfile import TemporaryDirectory
 from test.conftest import ramalama_container_engine

+import bcrypt
 import pytest

@@ -55,11 +56,9 @@ def container_registry():
     shutil.copy(work_dir / "domain.crt", trusted_certs_dir)

     # Create htpasswd file
-    subprocess.run(
-        f"htpasswd -Bbn {registry_username} {registry_password} > {htpasswd_file.as_posix()}",
-        shell=True,
-        check=True,
-    )
+    with open(htpasswd_file, "w") as pwfile:
+        passwd_hash = bcrypt.hashpw(registry_password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")
+        pwfile.write(f"{registry_username}:{passwd_hash}")

     # Start the registry
     subprocess.run(
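
Generating the credentials in-process drops the dependency on the `htpasswd` binary and the shell redirection, both of which behave differently on Windows. `htpasswd -B` entries are bcrypt hashes, so the `bcrypt` package (added to the dev extras above) can both produce and verify them:

    import bcrypt

    password = b"secret"
    passwd_hash = bcrypt.hashpw(password, bcrypt.gensalt()).decode("utf-8")
    entry = f"testuser:{passwd_hash}"  # same shape as a line from `htpasswd -Bbn`

    # Round-trip check, roughly what the registry's auth layer does:
    user, hashed = entry.split(":", 1)
    assert bcrypt.checkpw(password, hashed.encode("utf-8"))
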
@@ -1,7 +1,6 @@
 import json
 import re
 from datetime import datetime
-from test.conftest import xfail_if_windows
 from test.e2e.utils import RamalamaExecWorkspace

 import pytest
@@ -60,7 +59,6 @@ def test_json_output(shared_ctx):


 @pytest.mark.e2e
-@xfail_if_windows  # FIXME: Exception: Failed to remove the following models: ollama\library\smollm://:135m
 def test_all_images_removed(shared_ctx):
     shared_ctx.check_call(["ramalama", "rm", "-a"])
     result = shared_ctx.check_output(["ramalama", "list", "--noheading"])

@@ -1,14 +1,12 @@
 import random
 import re
 from subprocess import STDOUT, CalledProcessError
-from test.conftest import xfail_if_windows
 from test.e2e.utils import check_output

 import pytest


 @pytest.mark.e2e
-@xfail_if_windows  # FIXME: AttributeError: module 'os' has no attribute 'fork'
 def test_delete_non_existing_image():
     image_name = f"rm_random_image_{random.randint(0, 9999)}"
     with pytest.raises(CalledProcessError) as exc_info:
@@ -16,7 +14,7 @@ def test_delete_non_existing_image():

     assert exc_info.value.returncode == 22
     assert re.match(
-        f"Error: Model '{image_name}' not found\n",
+        f"Error: Model '{image_name}' not found",
         exc_info.value.output.decode("utf-8"),
     )

@@ -1,14 +1,13 @@
 import json
 import platform
 import re
-from subprocess import STDOUT, CalledProcessError
+from subprocess import PIPE, STDOUT, CalledProcessError
 from test.conftest import (
     skip_if_container,
     skip_if_darwin,
     skip_if_docker,
     skip_if_gh_actions_darwin,
     skip_if_no_container,
-    xfail_if_windows,
 )
 from test.e2e.utils import RamalamaExecWorkspace, check_output
@@ -38,177 +37,175 @@ def shared_ctx_with_models(test_model):


 @pytest.mark.e2e
-@xfail_if_windows  # FIXME: AttributeError: module 'os' has no attribute 'fork'
 @skip_if_no_container
 def test_basic_dry_run():
     ramalama_info = json.loads(check_output(["ramalama", "info"]))
     conman = ramalama_info["Engine"]["Name"]

-    result = check_output(["ramalama", "-q", "--dryrun", "run", TEST_MODEL])
+    result = check_output(["ramalama", "-q", "--dryrun", "run", TEST_MODEL], stdin=PIPE)
     assert not result.startswith(f"{conman} run --rm")
     assert not re.search(r".*-t -i", result), "run without terminal"

-    result = check_output(["ramalama", "-q", "--dryrun", "run", TEST_MODEL, "what's up doc?"])
+    result = check_output(["ramalama", "-q", "--dryrun", "run", TEST_MODEL, "what's up doc?"], stdin=PIPE)
     assert result.startswith(f"{conman} run")
     assert not re.search(r".*-t -i", result), "run without terminal"

-    result = check_output(f'echo "Test" | ramalama -q --dryrun run {TEST_MODEL}', shell=True)
+    result = check_output(f'echo "Test" | ramalama -q --dryrun run {TEST_MODEL}', shell=True, stdin=PIPE)
     assert result.startswith(f"{conman} run")
     assert not re.search(r".*-t -i", result), "run without terminal"


 @pytest.mark.e2e
-@xfail_if_windows  # FIXME: AttributeError: module 'os' has no attribute 'fork'
 @pytest.mark.parametrize(
-    "extra_params, pattern, config, env_vars, expected",
+    "extra_params, pattern, config, env_vars, expected, stdin",
     [
         # fmt: off
         pytest.param(
-            [], f".*{TEST_MODEL_FULL_NAME}.*", None, None, True,
+            [], f".*{TEST_MODEL_FULL_NAME}.*", None, None, True, None,
            id="check test_model", marks=skip_if_no_container
        ),
        pytest.param(
-            [], r".*--cache-reuse 256", None, None, True,
+            [], r".*--cache-reuse 256", None, None, True, None,
            id="check cache-reuse is being set", marks=skip_if_no_container
        ),
        pytest.param(
-            [], r".*--ctx-size", None, None, False,
+            [], r".*--ctx-size", None, None, False, None,
            id="check ctx-size is not show by default", marks=skip_if_no_container
        ),
        pytest.param(
-            [], r".*--seed", None, None, False,
+            [], r".*--seed", None, None, False, None,
            id="check --seed is not set by default", marks=skip_if_no_container
        ),
        pytest.param(
-            [], r".*-t -i", None, None, False,
+            [], r".*-t -i", None, None, False, PIPE,
            id="check -t -i is not present without tty", marks=skip_if_no_container)
        ,
        pytest.param(
            ["--env", "a=b", "--env", "test=success", "--name", "foobar"],
-            r"--env a=b --env test=success", None, None, True,
+            r"--env a=b --env test=success", None, None, True, None,
            id="check --env", marks=skip_if_no_container,
        ),
        pytest.param(
-            ["--oci-runtime", "foobar"], r"--runtime foobar", None, None, True,
+            ["--oci-runtime", "foobar"], r"--runtime foobar", None, None, True, None,
            id="check --oci-runtime", marks=skip_if_no_container)
        ,
        pytest.param(
            ["--net", "bridge", "--name", "foobar"], r".*--network bridge",
-            None, {"RAMALAMA_CONFIG": "NUL" if platform.system() == "Windows" else '/dev/null'}, True,
+            None, {"RAMALAMA_CONFIG": "NUL" if platform.system() == "Windows" else '/dev/null'}, True, None,
            id="check --net=bridge with RAMALAMA_CONFIG=/dev/null", marks=skip_if_no_container,
        ),
        pytest.param(
            ["--name", "foobar"], f".*{TEST_MODEL_FULL_NAME}.*",
-            None, {"RAMALAMA_CONFIG": "NUL" if platform.system() == "Windows" else '/dev/null'}, True,
+            None, {"RAMALAMA_CONFIG": "NUL" if platform.system() == "Windows" else '/dev/null'}, True, None,
            id="check test_model with RAMALAMA_CONFIG=/dev/null", marks=skip_if_no_container,
        ),
        pytest.param(
            ["-c", "4096", "--name", "foobar"], r".*--ctx-size 4096",
-            None, {"RAMALAMA_CONFIG": "NUL" if platform.system() == "Windows" else '/dev/null'}, True,
+            None, {"RAMALAMA_CONFIG": "NUL" if platform.system() == "Windows" else '/dev/null'}, True, None,
            id="check --ctx-size 4096 with RAMALAMA_CONFIG=/dev/null", marks=skip_if_no_container,
        ),
        pytest.param(
            ["--cache-reuse", "512", "--name", "foobar"], r".*--cache-reuse 512", None,
-            {"RAMALAMA_CONFIG": "NUL" if platform.system() == "Windows" else '/dev/null'}, True,
+            {"RAMALAMA_CONFIG": "NUL" if platform.system() == "Windows" else '/dev/null'}, True, None,
            id="check --cache-reuse with RAMALAMA_CONFIG=/dev/null", marks=skip_if_no_container,
        ),
        pytest.param(
            ["--name", "foobar"], r".*--temp 0.8", None, {
                "RAMALAMA_CONFIG": "NUL" if platform.system() == "Windows" else '/dev/null'
-            }, True,
+            }, True, None,
            id="check --temp default value is 0.8 with RAMALAMA_CONFIG=/dev/null", marks=skip_if_no_container,
        ),
        pytest.param(
            ["--seed", "9876", "--name", "foobar"], r".*--seed 9876",
-            None, {"RAMALAMA_CONFIG": "NUL" if platform.system() == "Windows" else '/dev/null'}, True,
+            None, {"RAMALAMA_CONFIG": "NUL" if platform.system() == "Windows" else '/dev/null'}, True, None,
            id="check --seed 9876 with RAMALAMA_CONFIG=/dev/null", marks=skip_if_no_container,
        ),
        pytest.param(
            ["--name", "foobar"], r".*--pull newer", None,
-            {"RAMALAMA_CONFIG": "NUL" if platform.system() == "Windows" else '/dev/null'}, True,
+            {"RAMALAMA_CONFIG": "NUL" if platform.system() == "Windows" else '/dev/null'}, True, None,
            id="check pull policy with RAMALAMA_CONFIG=/dev/null", marks=[skip_if_no_container, skip_if_docker],
        ),
        pytest.param(
-            [], DEFAULT_PULL_PATTERN, None, None, True,
+            [], DEFAULT_PULL_PATTERN, None, None, True, None,
            id="check default pull policy",
            marks=[skip_if_no_container],
        ),
        pytest.param(
-            ["--pull", "never", "-c", "4096", "--name", "foobbar"], r".*--pull never", None, None, True,
+            ["--pull", "never", "-c", "4096", "--name", "foobbar"], r".*--pull never", None, None, True, None,
            id="check never pull policy", marks=skip_if_no_container,
        ),
        pytest.param(
-            [], r".*--pull never", CONFIG_WITH_PULL_NEVER, None, True,
+            [], r".*--pull never", CONFIG_WITH_PULL_NEVER, None, True, None,
            id="check pull policy with RAMALAMA_CONFIG", marks=skip_if_no_container
        ),
        pytest.param(
-            ["--name", "foobar"], r".*--name foobar", None, None, True,
+            ["--name", "foobar"], r".*--name foobar", None, None, True, None,
            id="check --name foobar", marks=skip_if_no_container
        ),
        pytest.param(
-            ["--name", "foobar"], r".*--cap-drop=all", None, None, True,
+            ["--name", "foobar"], r".*--cap-drop=all", None, None, True, None,
            id="check if --cap-drop=all is present", marks=skip_if_no_container
        ),
        pytest.param(
-            ["--name", "foobar"], r".*no-new-privileges", None, None, True,
+            ["--name", "foobar"], r".*no-new-privileges", None, None, True, None,
            id="check if --no-new-privs is present", marks=skip_if_no_container),
        pytest.param(
-            ["--selinux", "True"], r".*--security-opt=label=disable", None, None, False,
+            ["--selinux", "True"], r".*--security-opt=label=disable", None, None, False, None,
            id="check --selinux=True enables container separation", marks=skip_if_no_container),
        pytest.param(
-            ["--selinux", "False"], r".*--security-opt=label=disable", None, None, True,
+            ["--selinux", "False"], r".*--security-opt=label=disable", None, None, True, None,
            id="check --selinux=False disables container separation", marks=skip_if_no_container),
        pytest.param(
-            ["--runtime-args", "--foo -bar"], r".*--foo\s+-bar", None, None, True,
+            ["--runtime-args", "--foo -bar"], r".*--foo\s+-bar", None, None, True, None,
            id="check --runtime-args", marks=skip_if_no_container
        ),
        pytest.param(
-            ["--runtime-args", "--foo='a b c'"], r".*--foo=a b c", None, None, True,
+            ["--runtime-args", "--foo='a b c'"], r".*--foo=a b c", None, None, True, None,
            id="check --runtime-args=\"--foo='a b c'\"", marks=skip_if_no_container
        ),
        pytest.param(
-            ["--privileged"], r".*--privileged", None, None, True,
+            ["--privileged"], r".*--privileged", None, None, True, None,
            id="check --privileged", marks=skip_if_no_container
        ),
        pytest.param(
-            ["--privileged"], r".*--cap-drop=all", None, None, False,
+            ["--privileged"], r".*--cap-drop=all", None, None, False, None,
            id="check cap-drop=all is not set when --privileged", marks=skip_if_no_container
        ),
        pytest.param(
-            ["--privileged"], r".*no-new-privileges", None, None, False,
+            ["--privileged"], r".*no-new-privileges", None, None, False, None,
            id="check no-new-privileges is not set when --privileged", marks=skip_if_no_container
        ),
        pytest.param(
-            [], r".*foo:latest.*serve", None, {"RAMALAMA_IMAGE": "foo:latest"}, True,
+            [], r".*foo:latest.*serve", None, {"RAMALAMA_IMAGE": "foo:latest"}, True, None,
            id="check run with RAMALAMA_IMAGE=foo:latest", marks=skip_if_no_container
        ),
        pytest.param(
-            ["--ctx-size", "4096"], r".*serve.*--ctx-size 4096", None, None, True,
+            ["--ctx-size", "4096"], r".*serve.*--ctx-size 4096", None, None, True, None,
            id="check --ctx-size 4096", marks=skip_if_container,
        ),
        pytest.param(
-            ["--ctx-size", "4096"], r".*--cache-reuse 256.*", None, None, True,
+            ["--ctx-size", "4096"], r".*--cache-reuse 256.*", None, None, True, None,
            id="check --cache-reuse is set by default to 256", marks=skip_if_container,
        ),
        pytest.param(
-            [], r".*-e ASAHI_VISIBLE_DEVICES=99", None, {"ASAHI_VISIBLE_DEVICES": "99"}, True,
+            [], r".*-e ASAHI_VISIBLE_DEVICES=99", None, {"ASAHI_VISIBLE_DEVICES": "99"}, True, None,
            id="check ASAHI_VISIBLE_DEVICES env var", marks=skip_if_no_container,
        ),
        pytest.param(
-            [], r".*-e CUDA_LAUNCH_BLOCKING=1", None, {"CUDA_LAUNCH_BLOCKING": "1"}, True,
+            [], r".*-e CUDA_LAUNCH_BLOCKING=1", None, {"CUDA_LAUNCH_BLOCKING": "1"}, True, None,
            id="check CUDA_LAUNCH_BLOCKING env var", marks=skip_if_no_container
        ),
        pytest.param(
-            [], r".*-e HIP_VISIBLE_DEVICES=99", None, {"HIP_VISIBLE_DEVICES": "99"}, True,
+            [], r".*-e HIP_VISIBLE_DEVICES=99", None, {"HIP_VISIBLE_DEVICES": "99"}, True, None,
            id="check HIP_VISIBLE_DEVICES env var", marks=skip_if_no_container
        ),
        pytest.param(
-            [], r".*-e HSA_OVERRIDE_GFX_VERSION=0.0.0", None, {"HSA_OVERRIDE_GFX_VERSION": "0.0.0"}, True,
+            [], r".*-e HSA_OVERRIDE_GFX_VERSION=0.0.0", None, {"HSA_OVERRIDE_GFX_VERSION": "0.0.0"}, True, None,
            id="check HSA_OVERRIDE_GFX_VERSION env var", marks=skip_if_no_container,
        ),
        pytest.param(
            [], r"(.*-e (HIP_VISIBLE_DEVICES=99|HSA_OVERRIDE_GFX_VERSION=0.0.0)){2}",
-            None, {"HIP_VISIBLE_DEVICES": "99", "HSA_OVERRIDE_GFX_VERSION": "0.0.0"}, True,
+            None, {"HIP_VISIBLE_DEVICES": "99", "HSA_OVERRIDE_GFX_VERSION": "0.0.0"}, True, None,
            id="check HIP_VISIBLE_DEVICES & HSA_OVERRIDE_GFX_VERSION env vars", marks=skip_if_no_container,
        ),
        pytest.param(
@@ -216,17 +213,17 @@ def test_basic_dry_run():
            "--device", "NUL" if platform.system() == "Windows" else '/dev/null',
            "--pull", "never"
        ],
-        r".*--device (NUL|/dev/null) .*", None, None, True,
+        r".*--device (NUL|/dev/null) .*", None, None, True, None,
        id="check --device=/dev/null", marks=skip_if_no_container),
        pytest.param(
-            ["--device", "none", "--pull", "never"], r".*--device.*", None, None, False,
+            ["--device", "none", "--pull", "never"], r".*--device.*", None, None, False, None,
            id="check --device with unsupported value", marks=skip_if_no_container),
        # fmt: on
    ],
 )
-def test_params(extra_params, pattern, config, env_vars, expected):
+def test_params(extra_params, pattern, config, env_vars, expected, stdin):
     with RamalamaExecWorkspace(config=config, env_vars=env_vars) as ctx:
-        result = ctx.check_output(RAMALAMA_DRY_RUN + extra_params + [TEST_MODEL])
+        result = ctx.check_output(RAMALAMA_DRY_RUN + extra_params + [TEST_MODEL], stdin=stdin)
         assert bool(re.search(pattern, result)) is expected
@@ -271,7 +268,6 @@ def test_params_errors(extra_params, pattern, config, env_vars, expected_exit_co


 @pytest.mark.e2e
-@xfail_if_windows  # FIXME: AttributeError: module 'os' has no attribute 'fork'
 @skip_if_darwin  # test is broken on MAC --no-container right now
 def test_run_model_with_prompt(shared_ctx_with_models, test_model):
     import platform
@@ -288,14 +284,12 @@ def test_run_model_with_prompt(shared_ctx_with_models, test_model):


 @pytest.mark.e2e
-@xfail_if_windows  # FIXME: AttributeError: module 'os' has no attribute 'fork'
 def test_run_keepalive(shared_ctx_with_models, test_model):
     ctx = shared_ctx_with_models
     ctx.check_call(["ramalama", "run", "--keepalive", "1s", test_model])


 @pytest.mark.e2e
-@xfail_if_windows  # FIXME: AttributeError: module 'os' has no attribute 'fork'
 @skip_if_no_container
 @skip_if_docker
 @skip_if_gh_actions_darwin
@@ -319,7 +313,7 @@ def test_run_keepalive(shared_ctx_with_models, test_model):
                "tiny",
            ],
            22,
-            r".*Error: quay.io/ramalama/testrag: image not known",
+            r".*quay.io/ramalama/testrag: image not known",
            id="non-existing-image-with-rag",
        ),
    ],
@@ -333,7 +327,6 @@ def test_run_with_non_existing_images_new(shared_ctx_with_models, run_args, exit


 @pytest.mark.e2e
-@xfail_if_windows  # FIXME: AttributeError: module 'os' has no attribute 'fork'
 @skip_if_no_container
 @skip_if_darwin
 @skip_if_docker

@@ -16,7 +16,6 @@ from test.conftest import (
     skip_if_gh_actions_darwin,
     skip_if_no_container,
     skip_if_not_darwin,
-    xfail_if_windows,
 )
 from test.e2e.utils import RamalamaExecWorkspace, check_output

@@ -288,7 +287,6 @@ def test_full_model_name_expansion():


 @pytest.mark.e2e
-@xfail_if_windows  # FIXME: Error: no container with name or ID "serve_and_stop_dyGXy" found: no such container
 @skip_if_no_container
 def test_serve_and_stop(shared_ctx, test_model):
     ctx = shared_ctx
@@ -334,7 +332,6 @@ def test_serve_and_stop(shared_ctx, test_model):


 @pytest.mark.e2e
-@xfail_if_windows  # FIXME: Container not starting?
 @skip_if_no_container
 def test_serve_multiple_models(shared_ctx, test_model):
     ctx = shared_ctx
@@ -459,7 +456,6 @@ def test_generation_with_bad_add_to_unit_flag_value(test_model):


 @pytest.mark.e2e
-@xfail_if_windows  # FIXME: registry fixture currently doesn't work on windows
 @skip_if_no_container
 @pytest.mark.xfail("config.option.container_engine == 'docker'", reason="docker login does not support --tls-verify")
 def test_quadlet_and_kube_generation_with_container_registry(container_registry, is_container, test_model):
@@ -681,7 +677,6 @@ def test_kube_generation_with_llama_api(test_model):


 @pytest.mark.e2e
-@xfail_if_windows  # FIXME: failing with exit code 5
 @skip_if_docker
 @skip_if_no_container
 def test_serve_api(caplog):
@@ -730,7 +725,6 @@ def test_serve_api(caplog):


 @pytest.mark.e2e
-@xfail_if_windows  # FIXME: AttributeError: module 'os' has no attribute 'fork'
 @skip_if_no_container
 @skip_if_docker
 @skip_if_gh_actions_darwin
@@ -757,11 +751,10 @@ def test_serve_with_non_existing_images():
         stderr=STDOUT,
     )
     assert exc_info.value.returncode == 22
-    assert re.search(r"Error: quay.io/ramalama/rag: image not known.*", exc_info.value.output.decode("utf-8"))
+    assert re.search(r"quay.io/ramalama/rag: image not known.*", exc_info.value.output.decode("utf-8"))


 @pytest.mark.e2e
-@xfail_if_windows  # FIXME: AttributeError: module 'os' has no attribute 'fork'
 @skip_if_no_container
 @skip_if_darwin
 @skip_if_docker

@@ -2,9 +2,15 @@ import os

 import pytest

+from ramalama.common import generate_sha256_binary
 from ramalama.model_store.global_store import GlobalModelStore
 from ramalama.model_store.reffile import RefJSONFile, StoreFile, StoreFileType
-from ramalama.model_store.snapshot_file import SnapshotFile, SnapshotFileType, validate_snapshot_files
+from ramalama.model_store.snapshot_file import (
+    LocalSnapshotFile,
+    SnapshotFile,
+    SnapshotFileType,
+    validate_snapshot_files,
+)
 from ramalama.model_store.store import ModelStore
 from ramalama.model_store.template_conversion import wrap_template_with_messages_loop

@@ -81,4 +87,28 @@ def test_try_convert_existing_chat_template_converts_flat_jinja(tmp_path, monkey
     assert len(captured["files"]) == 1
     converted_file = captured["files"][0]
     assert converted_file.type == SnapshotFileType.ChatTemplate
-    assert converted_file.content == wrap_template_with_messages_loop(original_template)
+    assert converted_file.content == wrap_template_with_messages_loop(original_template).encode("utf-8")
+
+
+def test_local_snapshot_file_binary_download_and_digest(tmp_path):
+    # Use a payload that includes a null byte to ensure we are truly treating this as binary data
+    content = b"binary-\x00-test-content"
+    expected_digest = generate_sha256_binary(content)
+
+    # Create a LocalSnapshotFile with known binary content
+    snapshot_file = LocalSnapshotFile(
+        name="test.bin",
+        type=SnapshotFileType.Other,
+        content=content,
+    )
+
+    # Act: download to a temporary path
+    target_path = tmp_path / "downloaded.bin"
+    snapshot_file.download(target_path, tmp_path)
+
+    # Assert: file contents are exactly what we provided
+    on_disk = target_path.read_bytes()
+    assert on_disk == content
+
+    # Assert: digest matches generate_sha256_binary(content)
+    assert snapshot_file.hash == expected_digest
@@ -39,7 +39,7 @@ class OllamaRepositoryMock(OllamaRepository):
        }

     def get_file_list(self, tag, cached_files, is_model_in_ollama_cache, manifest=None) -> list[SnapshotFile]:
-        return [LocalSnapshotFile("dummy content", "dummy", SnapshotFileType.Other)]
+        return [LocalSnapshotFile("dummy content".encode("utf-8"), "dummy", SnapshotFileType.Other)]


 def test_ollama_model_pull(ollama_model):

@@ -210,14 +210,15 @@ class TestMLXRuntime:

     @patch('ramalama.transports.base.platform.system')
     @patch('ramalama.transports.base.platform.machine')
-    @patch('ramalama.transports.base.os.fork')
+    @patch('ramalama.transports.base.Transport.serve_nonblocking', return_value=MagicMock())
     @patch('ramalama.chat.chat')
     @patch('socket.socket')
-    def test_mlx_run_uses_server_client_model(self, mock_socket_class, mock_chat, mock_fork, mock_machine, mock_system):
+    def test_mlx_run_uses_server_client_model(
+        self, mock_socket_class, mock_chat, mock_serve_nonblocking, mock_machine, mock_system
+    ):
         """Test that MLX runtime uses server-client model in run method"""
         mock_system.return_value = "Darwin"
         mock_machine.return_value = "arm64"
-        mock_fork.return_value = 123  # Parent process

         # Mock socket to simulate successful connection (server ready)
         mock_socket = MagicMock()
@@ -249,16 +250,16 @@ class TestMLXRuntime:
         with patch('sys.stdin.isatty', return_value=True):  # Mock tty for interactive mode
             model.run(args, cmd)

-        # Verify that fork was called (indicating server-client model)
-        mock_fork.assert_called_once()
-
         # Verify that chat.chat was called (parent process)
         mock_chat.assert_called_once()

+        # Verify that serve_nonblocking was called (indicating server-client model) and that the server_process is set
+        mock_serve_nonblocking.assert_called_once()
+        assert mock_chat.call_args[0][0].server_process == mock_serve_nonblocking.return_value
+
         # Verify args were set up correctly for server-client model
         # MLX runtime uses OpenAI-compatible endpoints under /v1
         assert args.url == "http://127.0.0.1:8080/v1"
-        assert args.pid2kill == 123


 class TestOCIModelSetupMounts: