1
0
mirror of https://github.com/containers/ramalama.git synced 2026-02-05 15:47:26 +01:00

Resolve review comments

Signed-off-by: Oliver Walsh <owalsh@redhat.com>
This commit is contained in:
Oliver Walsh
2026-01-14 15:28:26 +00:00
parent aa7a271c73
commit f37ec5cc88
7 changed files with 57 additions and 44 deletions

View File

@@ -172,24 +172,10 @@ def get_parser():
def init_cli():
"""Initialize the RamaLama CLI and parse command line arguments."""
# Need to know if we're running with --dryrun or --generate before adding the subcommands,
# otherwise calls to accel_image() when setting option defaults will cause unnecessary image pulls.
if any(arg in ("--dryrun", "--dry-run", "--generate") or arg.startswith("--generate=") for arg in sys.argv[1:]):
CONFIG.dryrun = True
# Phase 1: Parse the initial arguments to set CONFIG.runtime etc... as this can affect the subcommands
initial_parser = get_initial_parser()
initial_args, _ = initial_parser.parse_known_args()
for arg in initial_args.__dict__.keys():
if hasattr(CONFIG, arg):
setattr(CONFIG, arg, getattr(initial_args, arg))
# Phase 2: Re-parse the arguments with the subcommands enabled
parser = get_parser()
args = parser.parse_args()
post_parse_setup(args)
return parser, args
return parse_args_from_cmd(sys.argv[1:])
def parse_args_from_cmd(cmd: list[str]) -> argparse.Namespace:
def parse_args_from_cmd(cmd: list[str]) -> tuple[argparse.ArgumentParser, argparse.Namespace]:
"""Parse arguments from a command given as a list of argument strings"""
# Need to know if we're running with --dryrun or --generate before adding the subcommands,
# otherwise calls to accel_image() when setting option defaults will cause unnecessary image pulls.
@@ -205,7 +191,7 @@ def parse_args_from_cmd(cmd: list[str]) -> argparse.Namespace:
parser = get_parser()
args = parser.parse_args(cmd)
post_parse_setup(args)
return args
return parser, args
def get_description():

View File

@@ -1,6 +1,7 @@
import argparse
import ast
import json
import shlex
from pathlib import Path
from typing import Any
@@ -58,7 +59,7 @@ class CommandFactory:
# FIXME: binary should be a string array to work with nocontainer
binary = CommandFactory.eval_stmt(engine.binary, ctx)
if is_truthy(binary):
cmd += binary.split(" ")
cmd += shlex.split(binary)
else:
cmd.append(ContainerEntryPoint())

View File

@@ -765,6 +765,12 @@ def attempt_to_use_versioned(conman: str, image: str, vers: str, quiet: bool, sh
return False
class ContainerEntryPoint:
class ContainerEntryPoint(str):
def __init__(self, entrypoint: Optional[str] = None):
self.entrypoint = entrypoint
def __str__(self):
return str(self.entrypoint)
def __repr__(self):
return repr(self.entrypoint)

View File

@@ -125,7 +125,7 @@ class DaemonAPIHandler(APIHandler):
ramalama_cmd = ["ramalama", "--runtime", serve_request.runtime, "serve", model.model_name, "--port", str(port)]
for arg, val in serve_request.exec_args.items():
ramalama_cmd.extend([arg, val])
args = parse_args_from_cmd(ramalama_cmd)
_, args = parse_args_from_cmd(ramalama_cmd)
# always log to file at the model-specific location
args.logfile = f"{DEFAULT_LOG_DIR}/{model.model_organization}_{model.model_name}_{model.model_tag}.log"
inference_engine_command = assemble_command(args)

View File

@@ -441,7 +441,7 @@ def is_healthy(args, timeout: int = 3, model_name: str | None = None):
conn = None
try:
conn = HTTPConnection("127.0.0.1", args.port, timeout=timeout)
if args.debug:
if getattr(args, "debug", False):
conn.set_debuglevel(1)
if CONFIG.runtime == 'vllm':
conn.request("GET", "/ping")
@@ -491,13 +491,14 @@ def wait_for_healthy(args, health_func: Callable[[Any], bool], timeout=None):
logger.debug(f"Waiting for container {args.name} to become healthy (timeout: {timeout}s)...")
start_time = time.time()
display_dots = not getattr(args, "debug", False) and sys.stdin.isatty()
n = 0
while time.time() - start_time < timeout:
try:
if not args.debug:
if display_dots:
perror('\r' + n * '.', end='', flush=True)
if health_func(args):
if not args.debug:
if display_dots:
perror('\r' + n * ' ' + '\r', end='', flush=True)
return
except (ConnectionError, HTTPException, UnicodeDecodeError, json.JSONDecodeError) as e:

View File

@@ -133,6 +133,7 @@ EOF
}
@test "ramalama verify default engine" {
# Cannot run on docker as that also sets RAMALAMA_CONTAINER_ENGINE
skip_if_docker
engine=e_$(safename)
RAMALAMA_CONTAINER_ENGINE=${engine} run_ramalama --help

View File

@@ -7,7 +7,7 @@ from unittest.mock import Mock, patch
import pytest
from ramalama.engine import Engine, containers, dry_run, images, is_healthy, wait_for_healthy
import ramalama.engine
class TestEngine(unittest.TestCase):
@@ -23,13 +23,13 @@ class TestEngine(unittest.TestCase):
)
def test_init_basic(self):
engine = Engine(self.base_args)
engine = ramalama.engine.Engine(self.base_args)
self.assertEqual(engine.use_podman, True)
self.assertEqual(engine.use_docker, False)
def test_add_container_labels(self):
args = Namespace(**vars(self.base_args), MODEL="test-model", port="8080", subcommand="run")
engine = Engine(args)
engine = ramalama.engine.Engine(args)
exec_args = engine.exec_args
self.assertNotIn("--rm", exec_args)
self.assertIn("--label", exec_args)
@@ -39,7 +39,7 @@ class TestEngine(unittest.TestCase):
def test_serve_rm(self):
args = Namespace(**vars(self.base_args), MODEL="test-model", port="8080", subcommand="serve")
engine = Engine(args)
engine = ramalama.engine.Engine(args)
exec_args = engine.exec_args
self.assertIn("--rm", exec_args)
@@ -50,14 +50,14 @@ class TestEngine(unittest.TestCase):
mock_os_access.return_value = True
# Test Podman
podman_engine = Engine(self.base_args)
podman_engine = ramalama.engine.Engine(self.base_args)
self.assertIn("--runtime", podman_engine.exec_args)
self.assertIn("/usr/bin/nvidia-container-runtime", podman_engine.exec_args)
# Test Podman when nvidia-container-runtime executable is missing
# This is expected with the official package
mock_os_access.return_value = False
podman_engine = Engine(self.base_args)
podman_engine = ramalama.engine.Engine(self.base_args)
self.assertNotIn("--runtime", podman_engine.exec_args)
self.assertNotIn("/usr/bin/nvidia-container-runtime", podman_engine.exec_args)
@@ -65,30 +65,30 @@ class TestEngine(unittest.TestCase):
args = self.base_args
args.engine = "docker"
docker_args = Namespace(**vars(args))
docker_engine = Engine(docker_args)
docker_engine = ramalama.engine.Engine(docker_args)
self.assertIn("--runtime", docker_engine.exec_args)
self.assertIn("nvidia", docker_engine.exec_args)
def test_add_privileged_options(self):
# Test non-privileged (default)
engine = Engine(self.base_args)
engine = ramalama.engine.Engine(self.base_args)
self.assertIn("--security-opt=label=disable", engine.exec_args)
self.assertIn("--cap-drop=all", engine.exec_args)
# Test privileged
privileged_args = Namespace(**vars(self.base_args), privileged=True)
privileged_engine = Engine(privileged_args)
privileged_engine = ramalama.engine.Engine(privileged_args)
self.assertIn("--privileged", privileged_engine.exec_args)
def test_add_selinux(self):
self.base_args.selinux = True
# Test non-privileged (default)
engine = Engine(self.base_args)
engine = ramalama.engine.Engine(self.base_args)
self.assertNotIn("--security-opt=label=disable", engine.exec_args)
def test_add_port_option(self):
args = Namespace(**vars(self.base_args), port="8080")
engine = Engine(args)
engine = ramalama.engine.Engine(args)
self.assertIn("-p", engine.exec_args)
self.assertIn("8080:8080", engine.exec_args)
@@ -96,7 +96,7 @@ class TestEngine(unittest.TestCase):
def test_images(self, mock_run_cmd):
mock_run_cmd.return_value.stdout = b"image1\nimage2\n"
args = Namespace(engine="podman", debug=False, format="", noheading=False, notrunc=False)
result = images(args)
result = ramalama.engine.images(args)
self.assertEqual(result, ["image1", "image2"])
mock_run_cmd.assert_called_once()
@@ -104,20 +104,20 @@ class TestEngine(unittest.TestCase):
def test_containers(self, mock_run_cmd):
mock_run_cmd.return_value.stdout = b"container1\ncontainer2\n"
args = Namespace(engine="podman", debug=False, format="", noheading=False, notrunc=False)
result = containers(args)
result = ramalama.engine.containers(args)
self.assertEqual(result, ["container1", "container2"])
mock_run_cmd.assert_called_once()
def test_dry_run(self):
with patch('sys.stdout') as mock_stdout:
dry_run(["podman", "run", "--rm", "test-image"])
ramalama.engine.dry_run(["podman", "run", "--rm", "test-image"])
mock_stdout.write.assert_called()
@patch("ramalama.engine.HTTPConnection")
def test_is_healthy_conn(mock_conn):
args = Namespace(MODEL="themodel", name="thecontainer", port=8080, debug=False)
is_healthy(args, model_name="themodel")
ramalama.engine.is_healthy(args, model_name="themodel")
mock_conn.assert_called_once_with("127.0.0.1", args.port, timeout=3)
@@ -152,7 +152,7 @@ def test_is_healthy_fail(mock_conn, mock_debug, health_status, models_status, mo
responses.append(mock_models_resp)
mock_conn.return_value.getresponse.side_effect = responses
args = Namespace(MODEL="themodel", name="thecontainer", port=8080, debug=False)
assert not is_healthy(args, model_name="themodel")
assert not ramalama.engine.is_healthy(args, model_name="themodel")
assert mock_conn.return_value.getresponse.call_count == len(responses)
if len(responses) > 1:
assert models_msg in mock_debug.call_args.args[0]
@@ -166,7 +166,7 @@ def test_is_healthy_unicode_fail(mock_conn):
mock_conn.return_value.getresponse.side_effect = [mock_health_resp, mock_models_resp]
args = Namespace(name="thecontainer", port=8080, debug=False)
with pytest.raises(UnicodeDecodeError):
is_healthy(args, model_name="themodel")
ramalama.engine.is_healthy(args, model_name="themodel")
assert mock_conn.return_value.getresponse.call_count == 2
@@ -185,11 +185,29 @@ def test_is_healthy_success(mock_conn, mock_debug, health_status):
mock_models_resp.read.return_value = '{"models": [{"name": "themodel"}]}'
mock_conn.return_value.getresponse.side_effect = [mock_health_resp, mock_models_resp]
args = Namespace(MODEL="themodel", name="thecontainer", port=8080, debug=False)
assert is_healthy(args, model_name="themodel")
assert ramalama.engine.is_healthy(args, model_name="themodel")
assert mock_conn.return_value.getresponse.call_count == 2
assert mock_debug.call_args.args[0] == "Container thecontainer is healthy"
@pytest.mark.parametrize(
    "status, ok",
    [
        (500, False),
        (404, False),
        (200, True),
    ],
)
@patch("ramalama.engine.HTTPConnection")
def test_is_healthy_vllm(mock_conn, status, ok):
    """Check is_healthy() against the vllm runtime's /ping endpoint.

    With CONFIG.runtime patched to "vllm", the mocked connection returns a
    response carrying the parametrized HTTP ``status``; the health result
    must equal ``ok`` and the first call on the connection must be the
    ``GET /ping`` request.
    """
    connection = mock_conn.return_value
    connection.getresponse.return_value.status = status
    args = Namespace(MODEL="themodel", name="thecontainer", port=8080, debug=False, runtime="vllm")

    # Only override the runtime for the duration of the health check.
    with patch.object(ramalama.engine.CONFIG, "runtime", "vllm"):
        healthy = ramalama.engine.is_healthy(args)

    assert healthy == ok
    first_call = connection.mock_calls[0]
    assert first_call.args == ('GET', '/ping')
@pytest.mark.parametrize(
"exc",
[
@@ -207,7 +225,7 @@ def test_wait_for_healthy_error(mock_logs, exc):
args = Namespace(name="thecontainer", debug=True, engine="podman")
with pytest.raises(TimeoutExpired):
wait_for_healthy(args, healthy_func, timeout=1)
ramalama.engine.wait_for_healthy(args, healthy_func, timeout=1)
mock_logs.assert_called_once()
@@ -219,7 +237,7 @@ def test_wait_for_healthy_timeout(mock_logs):
args = Namespace(name="thecontainer", debug=True, engine="podman")
with pytest.raises(TimeoutExpired, match="timed out after 0 seconds"):
wait_for_healthy(args, healthy_func, timeout=0)
ramalama.engine.wait_for_healthy(args, healthy_func, timeout=0)
mock_logs.assert_called_once()
@@ -229,7 +247,7 @@ def test_wait_for_healthy_success():
return True
args = Namespace(name="thecontainer", debug=False)
wait_for_healthy(args, healthy_func, timeout=1)
ramalama.engine.wait_for_healthy(args, healthy_func, timeout=1)
if __name__ == '__main__':