diff --git a/ramalama/cli.py b/ramalama/cli.py index 7b024879..f6b42f1d 100644 --- a/ramalama/cli.py +++ b/ramalama/cli.py @@ -172,24 +172,10 @@ def get_parser(): def init_cli(): """Initialize the RamaLama CLI and parse command line arguments.""" - # Need to know if we're running with --dryrun or --generate before adding the subcommands, - # otherwise calls to accel_image() when setting option defaults will cause unnecessary image pulls. - if any(arg in ("--dryrun", "--dry-run", "--generate") or arg.startswith("--generate=") for arg in sys.argv[1:]): - CONFIG.dryrun = True - # Phase 1: Parse the initial arguments to set CONFIG.runtime etc... as this can affect the subcommands - initial_parser = get_initial_parser() - initial_args, _ = initial_parser.parse_known_args() - for arg in initial_args.__dict__.keys(): - if hasattr(CONFIG, arg): - setattr(CONFIG, arg, getattr(initial_args, arg)) - # Phase 2: Re-parse the arguments with the subcommands enabled - parser = get_parser() - args = parser.parse_args() - post_parse_setup(args) - return parser, args + return parse_args_from_cmd(sys.argv[1:]) -def parse_args_from_cmd(cmd: list[str]) -> argparse.Namespace: +def parse_args_from_cmd(cmd: list[str]) -> tuple[argparse.ArgumentParser, argparse.Namespace]: """Parse arguments based on a command string""" # Need to know if we're running with --dryrun or --generate before adding the subcommands, # otherwise calls to accel_image() when setting option defaults will cause unnecessary image pulls. @@ -205,7 +191,7 @@ def parse_args_from_cmd(cmd: list[str]) -> argparse.Namespace: parser = get_parser() args = parser.parse_args(cmd) post_parse_setup(args) - return args + return parser, args def get_description(): diff --git a/ramalama/command/factory.py b/ramalama/command/factory.py index 41b4fc4d..0b73c2f9 100644 --- a/ramalama/command/factory.py +++ b/ramalama/command/factory.py @@ -1,6 +1,7 @@ import argparse import ast import json +import shlex from pathlib import Path from typing import Any @@ -58,7 +59,7 @@ class CommandFactory: # FIXME: binary should be a string array to work with nocontainer binary = CommandFactory.eval_stmt(engine.binary, ctx) if is_truthy(binary): - cmd += binary.split(" ") + cmd += shlex.split(binary) else: cmd.append(ContainerEntryPoint()) diff --git a/ramalama/common.py b/ramalama/common.py index 284175de..19a009bc 100644 --- a/ramalama/common.py +++ b/ramalama/common.py @@ -765,6 +765,12 @@ def attempt_to_use_versioned(conman: str, image: str, vers: str, quiet: bool, sh return False -class ContainerEntryPoint: +class ContainerEntryPoint(str): def __init__(self, entrypoint: Optional[str] = None): self.entrypoint = entrypoint + + def __str__(self): + return str(self.entrypoint) + + def __repr__(self): + return repr(self.entrypoint) diff --git a/ramalama/daemon/handler/daemon.py b/ramalama/daemon/handler/daemon.py index 8483139a..48465fc2 100644 --- a/ramalama/daemon/handler/daemon.py +++ b/ramalama/daemon/handler/daemon.py @@ -125,7 +125,7 @@ class DaemonAPIHandler(APIHandler): ramalama_cmd = ["ramalama", "--runtime", serve_request.runtime, "serve", model.model_name, "--port", str(port)] for arg, val in serve_request.exec_args.items(): ramalama_cmd.extend([arg, val]) - args = parse_args_from_cmd(ramalama_cmd) + _, args = parse_args_from_cmd(ramalama_cmd) # always log to file at the model-specific location args.logfile = f"{DEFAULT_LOG_DIR}/{model.model_organization}_{model.model_name}_{model.model_tag}.log" inference_engine_command = assemble_command(args) diff --git a/ramalama/engine.py b/ramalama/engine.py index 9dcc5b24..3c04cf7a 100644 --- a/ramalama/engine.py +++ b/ramalama/engine.py @@ -441,7 +441,7 @@ def is_healthy(args, timeout: int = 3, model_name: str | None = None): conn = None try: conn = HTTPConnection("127.0.0.1", args.port, timeout=timeout) - if args.debug: + if getattr(args, "debug", False): conn.set_debuglevel(1) if CONFIG.runtime == 'vllm': conn.request("GET", "/ping") @@ -491,13 +491,14 @@ def wait_for_healthy(args, health_func: Callable[[Any], bool], timeout=None): logger.debug(f"Waiting for container {args.name} to become healthy (timeout: {timeout}s)...") start_time = time.time() + display_dots = not getattr(args, "debug", False) and sys.stdin.isatty() n = 0 while time.time() - start_time < timeout: try: - if not args.debug: + if display_dots: perror('\r' + n * '.', end='', flush=True) if health_func(args): - if not args.debug: + if display_dots: perror('\r' + n * ' ' + '\r', end='', flush=True) return except (ConnectionError, HTTPException, UnicodeDecodeError, json.JSONDecodeError) as e: diff --git a/test/system/015-help.bats b/test/system/015-help.bats index 82194dc7..9d487908 100644 --- a/test/system/015-help.bats +++ b/test/system/015-help.bats @@ -133,6 +133,7 @@ EOF } @test "ramalama verify default engine" { + # Cannot run on docker as that also sets RAMALAMA_CONTAINER_ENGINE skip_if_docker engine=e_$(safename) RAMALAMA_CONTAINER_ENGINE=${engine} run_ramalama --help diff --git a/test/unit/test_engine.py b/test/unit/test_engine.py index d4de6ea6..b4c36374 100644 --- a/test/unit/test_engine.py +++ b/test/unit/test_engine.py @@ -7,7 +7,7 @@ from unittest.mock import Mock, patch import pytest -from ramalama.engine import Engine, containers, dry_run, images, is_healthy, wait_for_healthy +import ramalama.engine class TestEngine(unittest.TestCase): @@ -23,13 +23,13 @@ class TestEngine(unittest.TestCase): ) def test_init_basic(self): - engine = Engine(self.base_args) + engine = ramalama.engine.Engine(self.base_args) self.assertEqual(engine.use_podman, True) self.assertEqual(engine.use_docker, False) def test_add_container_labels(self): args = Namespace(**vars(self.base_args), MODEL="test-model", port="8080", subcommand="run") - engine = Engine(args) + engine = ramalama.engine.Engine(args) exec_args = engine.exec_args self.assertNotIn("--rm", exec_args) self.assertIn("--label", exec_args) @@ -39,7 +39,7 @@ class TestEngine(unittest.TestCase): def test_serve_rm(self): args = Namespace(**vars(self.base_args), MODEL="test-model", port="8080", subcommand="serve") - engine = Engine(args) + engine = ramalama.engine.Engine(args) exec_args = engine.exec_args self.assertIn("--rm", exec_args) @@ -50,14 +50,14 @@ class TestEngine(unittest.TestCase): mock_os_access.return_value = True # Test Podman - podman_engine = Engine(self.base_args) + podman_engine = ramalama.engine.Engine(self.base_args) self.assertIn("--runtime", podman_engine.exec_args) self.assertIn("/usr/bin/nvidia-container-runtime", podman_engine.exec_args) # Test Podman when nvidia-container-runtime executable is missing # This is expected with the official package mock_os_access.return_value = False - podman_engine = Engine(self.base_args) + podman_engine = ramalama.engine.Engine(self.base_args) self.assertNotIn("--runtime", podman_engine.exec_args) self.assertNotIn("/usr/bin/nvidia-container-runtime", podman_engine.exec_args) @@ -65,30 +65,30 @@ class TestEngine(unittest.TestCase): args = self.base_args args.engine = "docker" docker_args = Namespace(**vars(args)) - docker_engine = Engine(docker_args) + docker_engine = ramalama.engine.Engine(docker_args) self.assertIn("--runtime", docker_engine.exec_args) self.assertIn("nvidia", docker_engine.exec_args) def test_add_privileged_options(self): # Test non-privileged (default) - engine = Engine(self.base_args) + engine = ramalama.engine.Engine(self.base_args) self.assertIn("--security-opt=label=disable", engine.exec_args) self.assertIn("--cap-drop=all", engine.exec_args) # Test privileged privileged_args = Namespace(**vars(self.base_args), privileged=True) - privileged_engine = Engine(privileged_args) + privileged_engine = ramalama.engine.Engine(privileged_args) self.assertIn("--privileged", privileged_engine.exec_args) def test_add_selinux(self): self.base_args.selinux = True # Test non-privileged (default) - engine = Engine(self.base_args) + engine = ramalama.engine.Engine(self.base_args) self.assertNotIn("--security-opt=label=disable", engine.exec_args) def test_add_port_option(self): args = Namespace(**vars(self.base_args), port="8080") - engine = Engine(args) + engine = ramalama.engine.Engine(args) self.assertIn("-p", engine.exec_args) self.assertIn("8080:8080", engine.exec_args) @@ -96,7 +96,7 @@ class TestEngine(unittest.TestCase): def test_images(self, mock_run_cmd): mock_run_cmd.return_value.stdout = b"image1\nimage2\n" args = Namespace(engine="podman", debug=False, format="", noheading=False, notrunc=False) - result = images(args) + result = ramalama.engine.images(args) self.assertEqual(result, ["image1", "image2"]) mock_run_cmd.assert_called_once() @@ -104,20 +104,20 @@ class TestEngine(unittest.TestCase): def test_containers(self, mock_run_cmd): mock_run_cmd.return_value.stdout = b"container1\ncontainer2\n" args = Namespace(engine="podman", debug=False, format="", noheading=False, notrunc=False) - result = containers(args) + result = ramalama.engine.containers(args) self.assertEqual(result, ["container1", "container2"]) mock_run_cmd.assert_called_once() def test_dry_run(self): with patch('sys.stdout') as mock_stdout: - dry_run(["podman", "run", "--rm", "test-image"]) + ramalama.engine.dry_run(["podman", "run", "--rm", "test-image"]) mock_stdout.write.assert_called() @patch("ramalama.engine.HTTPConnection") def test_is_healthy_conn(mock_conn): args = Namespace(MODEL="themodel", name="thecontainer", port=8080, debug=False) - is_healthy(args, model_name="themodel") + ramalama.engine.is_healthy(args, model_name="themodel") mock_conn.assert_called_once_with("127.0.0.1", args.port, timeout=3) @@ -152,7 +152,7 @@ def test_is_healthy_fail(mock_conn, mock_debug, health_status, models_status, mo responses.append(mock_models_resp) mock_conn.return_value.getresponse.side_effect = responses args = Namespace(MODEL="themodel", name="thecontainer", port=8080, debug=False) - assert not is_healthy(args, model_name="themodel") + assert not ramalama.engine.is_healthy(args, model_name="themodel") assert mock_conn.return_value.getresponse.call_count == len(responses) if len(responses) > 1: assert models_msg in mock_debug.call_args.args[0] @@ -166,7 +166,7 @@ def test_is_healthy_unicode_fail(mock_conn): mock_conn.return_value.getresponse.side_effect = [mock_health_resp, mock_models_resp] args = Namespace(name="thecontainer", port=8080, debug=False) with pytest.raises(UnicodeDecodeError): - is_healthy(args, model_name="themodel") + ramalama.engine.is_healthy(args, model_name="themodel") assert mock_conn.return_value.getresponse.call_count == 2 @@ -185,11 +185,29 @@ def test_is_healthy_success(mock_conn, mock_debug, health_status): mock_models_resp.read.return_value = '{"models": [{"name": "themodel"}]}' mock_conn.return_value.getresponse.side_effect = [mock_health_resp, mock_models_resp] args = Namespace(MODEL="themodel", name="thecontainer", port=8080, debug=False) - assert is_healthy(args, model_name="themodel") + assert ramalama.engine.is_healthy(args, model_name="themodel") assert mock_conn.return_value.getresponse.call_count == 2 assert mock_debug.call_args.args[0] == "Container thecontainer is healthy" +@pytest.mark.parametrize( + "status, ok", + [ + (500, False), + (404, False), + (200, True), + ], +) +@patch("ramalama.engine.HTTPConnection") +def test_is_healthy_vllm(mock_conn, status, ok): + mock_resp = mock_conn.return_value.getresponse.return_value + mock_resp.status = status + args = Namespace(MODEL="themodel", name="thecontainer", port=8080, debug=False, runtime="vllm") + with patch.object(ramalama.engine.CONFIG, "runtime", "vllm"): + assert ramalama.engine.is_healthy(args) == ok + assert mock_conn.return_value.mock_calls[0].args == ('GET', '/ping') + + @pytest.mark.parametrize( "exc", [ @@ -207,7 +225,7 @@ def test_wait_for_healthy_error(mock_logs, exc): args = Namespace(name="thecontainer", debug=True, engine="podman") with pytest.raises(TimeoutExpired): - wait_for_healthy(args, healthy_func, timeout=1) + ramalama.engine.wait_for_healthy(args, healthy_func, timeout=1) mock_logs.assert_called_once() @@ -219,7 +237,7 @@ def test_wait_for_healthy_timeout(mock_logs): args = Namespace(name="thecontainer", debug=True, engine="podman") with pytest.raises(TimeoutExpired, match="timed out after 0 seconds"): - wait_for_healthy(args, healthy_func, timeout=0) + ramalama.engine.wait_for_healthy(args, healthy_func, timeout=0) mock_logs.assert_called_once() @@ -229,7 +247,7 @@ def test_wait_for_healthy_success(): return True args = Namespace(name="thecontainer", debug=False) - wait_for_healthy(args, healthy_func, timeout=1) + ramalama.engine.wait_for_healthy(args, healthy_func, timeout=1) if __name__ == '__main__':