1
0
mirror of https://github.com/containers/ramalama.git synced 2026-02-05 06:46:39 +01:00
Files
ramalama/test/unit/command/test_context.py
Michael Engel 042c2b0903 Refactor code base to use pathlib.Path were possible
Signed-off-by: Michael Engel <mengel@redhat.com>
2026-01-15 14:35:21 +01:00

120 lines
3.8 KiB
Python

import argparse
from typing import Any
from unittest.mock import MagicMock
import pytest
from ramalama.command import context
@pytest.mark.parametrize(
    "args_dict",
    [
        {},
        {
            "host": "192.168.178.1",
            "port": 1337,
            "thinking": False,
            "context": 512,
            "temp": 11,
            "debug": True,
            "webui": True,
            "ngl": 44,
            "threads": 8,
            "logfile": "/var/tmp/ramalama.log",
            "container": True,
            "model_draft": "draft",
            "seed": 12345,
            "runtime_args": "--another-arg 44 --more-args",
            "cache_reuse": 1024,
        },
        {
            "host": "192.168.178.1",
            "port": 1337,
            "thinking": False,
            "container": True,
            "model_draft": "draft",
        },
        {
            "doesntexist": "not added to context",
        },
    ],
)
def test_ramalama_args_context(args_dict: dict[str, Any]):
    """Check that RamalamaArgsContext.from_argparse copies CLI args onto the context.

    For every attribute of the resulting context, the value must equal the CLI
    argument of the same (mapped) name when that argument was supplied, and be
    None otherwise. An unknown CLI argument must not appear on the context.
    """
    # Context field names that differ from the CLI argument they are read from
    # (context field -> CLI argument name).
    ctx_to_cli_name = {
        "ctx_size": "context",
    }

    ctx = context.RamalamaArgsContext.from_argparse(argparse.Namespace(**args_dict))

    for ctx_field, ctx_value in vars(ctx).items():
        cli_name = ctx_to_cli_name.get(ctx_field, ctx_field)
        if cli_name in args_dict:
            expected_value = args_dict[cli_name]
            assert (
                ctx_value == expected_value
            ), f"Field '{cli_name}' expected to be '{expected_value}', but got '{ctx_value}'"
        else:
            assert ctx_value is None, f"Field '{cli_name}' expected to be None in args context"

    # The loop above only visits attributes the context already has. If an
    # unknown argument leaked onto the context, the loop would compare it
    # against args_dict and pass, so check for it explicitly.
    assert not hasattr(ctx, "doesntexist"), "Unknown argument 'doesntexist' must not be added to the args context"
@pytest.mark.parametrize(
    "is_container, should_generate, dry_run",
    [
        (True, True, False),
        (False, True, True),
        (True, False, True),
        (False, False, False),
    ],
)
def test_ramalama_model_context_properties(is_container, should_generate, dry_run):
    """Verify the identity and path properties exposed by RamalamaModelContext.

    Each path property must equal the value returned by the matching private
    getter on the model (or its draft model). Every getter must have been
    called with the same flags the context was constructed with.
    """
    flags = (is_container, should_generate, dry_run)

    draft = MagicMock()
    draft._get_entry_model_path.return_value = "/path/to/draft-model"

    model = MagicMock(model_name="smollm", model_tag="135m", model_organization="mock-org")
    model._get_entry_model_path.return_value = "/path/to/model"
    model._get_mmproj_path.return_value = "/path/to/mmproj"
    model._get_chat_template_path.return_value = "/path/to/chat-template"
    model.draft_model = draft

    ctx = context.RamalamaModelContext(
        model=model,
        is_container=is_container,
        should_generate=should_generate,
        dry_run=dry_run,
    )

    # Checked in insertion order; stops at the first mismatch.
    expected_properties = {
        "name": "smollm:135m",
        "alias": "mock-org/smollm",
        "model_path": "/path/to/model",
        "mmproj_path": "/path/to/mmproj",
        "chat_template_path": "/path/to/chat-template",
        "draft_model_path": "/path/to/draft-model",
    }
    for prop, expected in expected_properties.items():
        assert getattr(ctx, prop) == expected

    # Each getter's most recent call must carry the context's flags, positionally.
    for getter in (
        model._get_entry_model_path,
        model._get_mmproj_path,
        model._get_chat_template_path,
        draft._get_entry_model_path,
    ):
        getter.assert_called_with(*flags)
def test_ramalama_model_context_without_draft_model():
    """A model with no ``draft_model`` attribute yields ``draft_model_path is None``."""

    class ModelWithoutDraft:
        """Stand-in model that defines no attributes, in particular no ``draft_model``."""

    ctx = context.RamalamaModelContext(
        model=ModelWithoutDraft(),
        dry_run=True,
        is_container=False,
        should_generate=False,
    )

    assert ctx.draft_model_path is None