Resolve review comments
Signed-off-by: Oliver Walsh <owalsh@redhat.com>
@@ -172,24 +172,10 @@ def get_parser():
 def init_cli():
     """Initialize the RamaLama CLI and parse command line arguments."""
-    # Need to know if we're running with --dryrun or --generate before adding the subcommands,
-    # otherwise calls to accel_image() when setting option defaults will cause unnecessary image pulls.
-    if any(arg in ("--dryrun", "--dry-run", "--generate") or arg.startswith("--generate=") for arg in sys.argv[1:]):
-        CONFIG.dryrun = True
-
-    # Phase 1: Parse the initial arguments to set CONFIG.runtime etc... as this can affect the subcommands
-    initial_parser = get_initial_parser()
-    initial_args, _ = initial_parser.parse_known_args()
-    for arg in initial_args.__dict__.keys():
-        if hasattr(CONFIG, arg):
-            setattr(CONFIG, arg, getattr(initial_args, arg))
-
-    # Phase 2: Re-parse the arguments with the subcommands enabled
-    parser = get_parser()
-    args = parser.parse_args()
-    post_parse_setup(args)
-    return parser, args
+    return parse_args_from_cmd(sys.argv[1:])


-def parse_args_from_cmd(cmd: list[str]) -> argparse.Namespace:
+def parse_args_from_cmd(cmd: list[str]) -> tuple[argparse.ArgumentParser, argparse.Namespace]:
     """Parse arguments based on a command string"""
     # Need to know if we're running with --dryrun or --generate before adding the subcommands,
     # otherwise calls to accel_image() when setting option defaults will cause unnecessary image pulls.
@@ -205,7 +191,7 @@ def parse_args_from_cmd(cmd: list[str]) -> argparse.Namespace:
     parser = get_parser()
     args = parser.parse_args(cmd)
     post_parse_setup(args)
-    return args
+    return parser, args


 def get_description():
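For context, a minimal sketch of how a caller can use the new tuple return; it assumes parse_args_from_cmd is importable from ramalama.cli and that it accepts the same argv-style list the daemon handler builds further down in this diff (the runtime, model, and port values are invented):

# Sketch only: assumes parse_args_from_cmd is exposed by ramalama.cli and
# accepts an argv-style list of CLI tokens (values here are made up).
from ramalama.cli import parse_args_from_cmd

cmd = ["ramalama", "--runtime", "llama.cpp", "serve", "tinyllama", "--port", "8085"]
parser, args = parse_args_from_cmd(cmd)
print(args.port)              # the parsed Namespace, as before
print(parser.format_usage())  # the parser itself is now available to callers

Returning the parser alongside the args lets callers reuse argparse's own usage and error reporting instead of reconstructing it.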
@@ -1,6 +1,7 @@
 import argparse
 import ast
 import json
+import shlex
 from pathlib import Path
 from typing import Any
@@ -58,7 +59,7 @@ class CommandFactory:
         # FIXME: binary should be a string array to work with nocontainer
         binary = CommandFactory.eval_stmt(engine.binary, ctx)
         if is_truthy(binary):
-            cmd += binary.split(" ")
+            cmd += shlex.split(binary)
         else:
             cmd.append(ContainerEntryPoint())
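The move from binary.split(" ") to shlex.split matters once the configured binary can carry quoted or multi-word arguments; a small stand-alone illustration (the sample value is invented):

import shlex

binary = 'podman --log-level "debug info"'
print(binary.split(" "))    # ['podman', '--log-level', '"debug', 'info"']
print(shlex.split(binary))  # ['podman', '--log-level', 'debug info']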
@@ -765,6 +765,12 @@ def attempt_to_use_versioned(conman: str, image: str, vers: str, quiet: bool, sh
    return False


-class ContainerEntryPoint:
+class ContainerEntryPoint(str):
     def __init__(self, entrypoint: Optional[str] = None):
         self.entrypoint = entrypoint
+
+    def __str__(self):
+        return str(self.entrypoint)
+
+    def __repr__(self):
+        return repr(self.entrypoint)
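Deriving ContainerEntryPoint from str presumably lets the sentinel sit in the same list of string command elements as ordinary arguments; a self-contained stand-in (not repository code) showing the behavioural difference:

# Minimal stand-in for the class above, for illustration only.
class ContainerEntryPoint(str):
    def __init__(self, entrypoint=None):
        self.entrypoint = entrypoint

cep = ContainerEntryPoint()
print(isinstance(cep, str))              # True; a plain (non-str) class fails this check
print(" ".join(["podman", "run", cep]))  # joins cleanly because cep is itself a str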
@@ -125,7 +125,7 @@ class DaemonAPIHandler(APIHandler):
         ramalama_cmd = ["ramalama", "--runtime", serve_request.runtime, "serve", model.model_name, "--port", str(port)]
         for arg, val in serve_request.exec_args.items():
             ramalama_cmd.extend([arg, val])
-        args = parse_args_from_cmd(ramalama_cmd)
+        _, args = parse_args_from_cmd(ramalama_cmd)
         # always log to file at the model-specific location
         args.logfile = f"{DEFAULT_LOG_DIR}/{model.model_organization}_{model.model_name}_{model.model_tag}.log"
         inference_engine_command = assemble_command(args)
@@ -441,7 +441,7 @@ def is_healthy(args, timeout: int = 3, model_name: str | None = None):
     conn = None
     try:
         conn = HTTPConnection("127.0.0.1", args.port, timeout=timeout)
-        if args.debug:
+        if getattr(args, "debug", False):
             conn.set_debuglevel(1)
         if CONFIG.runtime == 'vllm':
             conn.request("GET", "/ping")
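Using getattr with a default keeps is_healthy usable with Namespace objects that never defined debug at all (for example ones assembled programmatically); a quick stand-alone illustration:

from argparse import Namespace

args = Namespace(port=8080)           # no "debug" attribute at all
# args.debug would raise AttributeError here
print(getattr(args, "debug", False))  # prints False and carries on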
@@ -491,13 +491,14 @@ def wait_for_healthy(args, health_func: Callable[[Any], bool], timeout=None):
     logger.debug(f"Waiting for container {args.name} to become healthy (timeout: {timeout}s)...")
     start_time = time.time()

+    display_dots = not getattr(args, "debug", False) and sys.stdin.isatty()
     n = 0
     while time.time() - start_time < timeout:
         try:
-            if not args.debug:
+            if display_dots:
                 perror('\r' + n * '.', end='', flush=True)
             if health_func(args):
-                if not args.debug:
+                if display_dots:
                     perror('\r' + n * ' ' + '\r', end='', flush=True)
                 return
         except (ConnectionError, HTTPException, UnicodeDecodeError, json.JSONDecodeError) as e:
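The new display_dots flag only draws progress dots on an interactive terminal; isatty() is the standard check for that, as this tiny stand-alone example shows:

import sys

# True in an interactive shell, False when stdin is redirected or piped,
# so the progress dots are suppressed in CI logs and scripted runs.
print(sys.stdin.isatty())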
@@ -133,6 +133,7 @@ EOF
 }

 @test "ramalama verify default engine" {
+    # Cannot run on docker as that also sets RAMALAMA_CONTAINER_ENGINE
     skip_if_docker
     engine=e_$(safename)
     RAMALAMA_CONTAINER_ENGINE=${engine} run_ramalama --help
@@ -7,7 +7,7 @@ from unittest.mock import Mock, patch

 import pytest

-from ramalama.engine import Engine, containers, dry_run, images, is_healthy, wait_for_healthy
+import ramalama.engine


 class TestEngine(unittest.TestCase):
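Switching the tests from from-imports to module-qualified calls follows the usual "patch where it is looked up" rule: a name bound by a from-import at load time does not see a later patch of the module attribute, while ramalama.engine.<name> always resolves through the (possibly patched) module. A generic, hedged sketch of that pattern, not taken from the repository's tests:

from unittest.mock import patch

import ramalama.engine

with patch("ramalama.engine.dry_run") as fake_dry_run:
    # Resolving through the module sees the patched attribute...
    ramalama.engine.dry_run(["podman", "run", "--rm", "test-image"])
    fake_dry_run.assert_called_once()
    # ...whereas a name imported earlier with "from ramalama.engine import dry_run"
    # would still point at the original function object.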
@@ -23,13 +23,13 @@ class TestEngine(unittest.TestCase):
         )

     def test_init_basic(self):
-        engine = Engine(self.base_args)
+        engine = ramalama.engine.Engine(self.base_args)
         self.assertEqual(engine.use_podman, True)
         self.assertEqual(engine.use_docker, False)

     def test_add_container_labels(self):
         args = Namespace(**vars(self.base_args), MODEL="test-model", port="8080", subcommand="run")
-        engine = Engine(args)
+        engine = ramalama.engine.Engine(args)
         exec_args = engine.exec_args
         self.assertNotIn("--rm", exec_args)
         self.assertIn("--label", exec_args)
@@ -39,7 +39,7 @@ class TestEngine(unittest.TestCase):

     def test_serve_rm(self):
         args = Namespace(**vars(self.base_args), MODEL="test-model", port="8080", subcommand="serve")
-        engine = Engine(args)
+        engine = ramalama.engine.Engine(args)
         exec_args = engine.exec_args
         self.assertIn("--rm", exec_args)
@@ -50,14 +50,14 @@ class TestEngine(unittest.TestCase):
         mock_os_access.return_value = True

         # Test Podman
-        podman_engine = Engine(self.base_args)
+        podman_engine = ramalama.engine.Engine(self.base_args)
         self.assertIn("--runtime", podman_engine.exec_args)
         self.assertIn("/usr/bin/nvidia-container-runtime", podman_engine.exec_args)

         # Test Podman when nvidia-container-runtime executable is missing
         # This is expected with the official package
         mock_os_access.return_value = False
-        podman_engine = Engine(self.base_args)
+        podman_engine = ramalama.engine.Engine(self.base_args)
         self.assertNotIn("--runtime", podman_engine.exec_args)
         self.assertNotIn("/usr/bin/nvidia-container-runtime", podman_engine.exec_args)
@@ -65,30 +65,30 @@ class TestEngine(unittest.TestCase):
         args = self.base_args
         args.engine = "docker"
         docker_args = Namespace(**vars(args))
-        docker_engine = Engine(docker_args)
+        docker_engine = ramalama.engine.Engine(docker_args)
         self.assertIn("--runtime", docker_engine.exec_args)
         self.assertIn("nvidia", docker_engine.exec_args)

     def test_add_privileged_options(self):
         # Test non-privileged (default)
-        engine = Engine(self.base_args)
+        engine = ramalama.engine.Engine(self.base_args)
         self.assertIn("--security-opt=label=disable", engine.exec_args)
         self.assertIn("--cap-drop=all", engine.exec_args)

         # Test privileged
         privileged_args = Namespace(**vars(self.base_args), privileged=True)
-        privileged_engine = Engine(privileged_args)
+        privileged_engine = ramalama.engine.Engine(privileged_args)
         self.assertIn("--privileged", privileged_engine.exec_args)

     def test_add_selinux(self):
         self.base_args.selinux = True
         # Test non-privileged (default)
-        engine = Engine(self.base_args)
+        engine = ramalama.engine.Engine(self.base_args)
         self.assertNotIn("--security-opt=label=disable", engine.exec_args)

     def test_add_port_option(self):
         args = Namespace(**vars(self.base_args), port="8080")
-        engine = Engine(args)
+        engine = ramalama.engine.Engine(args)
         self.assertIn("-p", engine.exec_args)
         self.assertIn("8080:8080", engine.exec_args)
@@ -96,7 +96,7 @@ class TestEngine(unittest.TestCase):
     def test_images(self, mock_run_cmd):
         mock_run_cmd.return_value.stdout = b"image1\nimage2\n"
         args = Namespace(engine="podman", debug=False, format="", noheading=False, notrunc=False)
-        result = images(args)
+        result = ramalama.engine.images(args)
         self.assertEqual(result, ["image1", "image2"])
         mock_run_cmd.assert_called_once()
@@ -104,20 +104,20 @@ class TestEngine(unittest.TestCase):
     def test_containers(self, mock_run_cmd):
         mock_run_cmd.return_value.stdout = b"container1\ncontainer2\n"
         args = Namespace(engine="podman", debug=False, format="", noheading=False, notrunc=False)
-        result = containers(args)
+        result = ramalama.engine.containers(args)
         self.assertEqual(result, ["container1", "container2"])
         mock_run_cmd.assert_called_once()

     def test_dry_run(self):
         with patch('sys.stdout') as mock_stdout:
-            dry_run(["podman", "run", "--rm", "test-image"])
+            ramalama.engine.dry_run(["podman", "run", "--rm", "test-image"])
             mock_stdout.write.assert_called()


 @patch("ramalama.engine.HTTPConnection")
 def test_is_healthy_conn(mock_conn):
     args = Namespace(MODEL="themodel", name="thecontainer", port=8080, debug=False)
-    is_healthy(args, model_name="themodel")
+    ramalama.engine.is_healthy(args, model_name="themodel")
     mock_conn.assert_called_once_with("127.0.0.1", args.port, timeout=3)
@@ -152,7 +152,7 @@ def test_is_healthy_fail(mock_conn, mock_debug, health_status, models_status, mo
         responses.append(mock_models_resp)
     mock_conn.return_value.getresponse.side_effect = responses
     args = Namespace(MODEL="themodel", name="thecontainer", port=8080, debug=False)
-    assert not is_healthy(args, model_name="themodel")
+    assert not ramalama.engine.is_healthy(args, model_name="themodel")
     assert mock_conn.return_value.getresponse.call_count == len(responses)
     if len(responses) > 1:
         assert models_msg in mock_debug.call_args.args[0]
@@ -166,7 +166,7 @@ def test_is_healthy_unicode_fail(mock_conn):
     mock_conn.return_value.getresponse.side_effect = [mock_health_resp, mock_models_resp]
     args = Namespace(name="thecontainer", port=8080, debug=False)
     with pytest.raises(UnicodeDecodeError):
-        is_healthy(args, model_name="themodel")
+        ramalama.engine.is_healthy(args, model_name="themodel")
     assert mock_conn.return_value.getresponse.call_count == 2
@@ -185,11 +185,29 @@ def test_is_healthy_success(mock_conn, mock_debug, health_status):
     mock_models_resp.read.return_value = '{"models": [{"name": "themodel"}]}'
     mock_conn.return_value.getresponse.side_effect = [mock_health_resp, mock_models_resp]
     args = Namespace(MODEL="themodel", name="thecontainer", port=8080, debug=False)
-    assert is_healthy(args, model_name="themodel")
+    assert ramalama.engine.is_healthy(args, model_name="themodel")
     assert mock_conn.return_value.getresponse.call_count == 2
     assert mock_debug.call_args.args[0] == "Container thecontainer is healthy"


+@pytest.mark.parametrize(
+    "status, ok",
+    [
+        (500, False),
+        (404, False),
+        (200, True),
+    ],
+)
+@patch("ramalama.engine.HTTPConnection")
+def test_is_healthy_vllm(mock_conn, status, ok):
+    mock_resp = mock_conn.return_value.getresponse.return_value
+    mock_resp.status = status
+    args = Namespace(MODEL="themodel", name="thecontainer", port=8080, debug=False, runtime="vllm")
+    with patch.object(ramalama.engine.CONFIG, "runtime", "vllm"):
+        assert ramalama.engine.is_healthy(args) == ok
+    assert mock_conn.return_value.mock_calls[0].args == ('GET', '/ping')
+
+
 @pytest.mark.parametrize(
     "exc",
     [
@@ -207,7 +225,7 @@ def test_wait_for_healthy_error(mock_logs, exc):

     args = Namespace(name="thecontainer", debug=True, engine="podman")
     with pytest.raises(TimeoutExpired):
-        wait_for_healthy(args, healthy_func, timeout=1)
+        ramalama.engine.wait_for_healthy(args, healthy_func, timeout=1)
     mock_logs.assert_called_once()
@@ -219,7 +237,7 @@ def test_wait_for_healthy_timeout(mock_logs):

     args = Namespace(name="thecontainer", debug=True, engine="podman")
     with pytest.raises(TimeoutExpired, match="timed out after 0 seconds"):
-        wait_for_healthy(args, healthy_func, timeout=0)
+        ramalama.engine.wait_for_healthy(args, healthy_func, timeout=0)
     mock_logs.assert_called_once()
@@ -229,7 +247,7 @@ def test_wait_for_healthy_success():
         return True

     args = Namespace(name="thecontainer", debug=False)
-    wait_for_healthy(args, healthy_func, timeout=1)
+    ramalama.engine.wait_for_healthy(args, healthy_func, timeout=1)


 if __name__ == '__main__':