1
0
mirror of https://github.com/containers/ramalama.git synced 2026-02-05 15:47:26 +01:00
Files
ramalama/test/unit/test_cli_args.py
2026-01-24 16:00:59 -06:00

131 lines
3.9 KiB
Python

import sys
from dataclasses import fields
from typing import get_args
from unittest.mock import MagicMock
import pytest
from ramalama.arg_types import ChatSubArgs, DefaultArgs
from ramalama.cli import get_parser
from ramalama.config import SUPPORTED_ENGINES
def quote_cli_arg(s: str) -> str:
    """
    Quote a generated string so argparse always receives it as a value.

    Wrapping in quotes guarantees the argument never starts with "-", so argparse
    cannot mistake it for an option flag. Strings without a single quote are
    wrapped in single quotes; strings that contain one are wrapped in double
    quotes, with any embedded double quotes backslash-escaped.
    """
    if "'" not in s:
        return f"'{s}'"
    # Computed outside the f-string: backslashes inside f-string braces are a
    # syntax error before Python 3.12.
    escaped = s.replace('"', '\\"')
    return f'"{escaped}"'


try:
    from hypothesis import given
    from hypothesis import strategies as st
except ImportError:
    HAS_HYPOTHESIS = False

    # Local stand-ins so the decorator expressions below still evaluate without
    # hypothesis. They are deliberately NOT installed into sys.modules, which
    # would hand this fake to every test module imported later in the session.
    # MagicMock absorbs any strategy call (builds, sampled_from, floats().map, ...).
    st = MagicMock()

    def given(*_strategies, **_kwargs):
        """
        Stand-in for ``hypothesis.given`` used when hypothesis is not installed.

        Replaces the decorated test with a no-argument stub function. pytest can
        still collect it (and has no strategy parameters to resolve as fixtures),
        so the ``skipif(not HAS_HYPOTHESIS)`` marker reports it as skipped.
        """

        def decorator(_test):
            def skipped_without_hypothesis():
                return None

            return skipped_without_hypothesis

        return decorator

else:
    HAS_HYPOTHESIS = True

    def shell_quoted_string_with_escaping():
        """
        CLI arguments can't be empty, so min_size is 1.
        Every generated string is additionally quoted (see quote_cli_arg) to
        avoid special-character and leading-dash issues.
        """
        return st.text(min_size=1).map(quote_cli_arg)

    # hypothesis type registration is global: every inferred ``str`` argument in
    # st.builds() below draws from the quoted-text strategy.
    st.register_type_strategy(str, shell_quoted_string_with_escaping())
# Build the CLI parser once at import time; both tests below parse their
# generated argument lists against this shared instance.
parser = get_parser()
# Dataclass attributes whose CLI flag is spelled differently from the attribute
# name (argparse turns "--api-key" into the attribute "api_key", and so on).
# Attributes not listed here are used verbatim as the flag name.
special_cases = dict(
    api_key="api-key",
    max_tokens="max-tokens",
)
def args_to_cli_args(args_obj, subcommand: str | None, special_cases: dict | None = None) -> list:
"""
Convert a dataclass instance to CLI arguments for argparse.
- subcommand: the CLI subcommand (e.g., 'chat')
- special_cases: dict mapping attribute names to CLI flag names (e.g., {'api_key': 'api-key'})
"""
if special_cases is None:
special_cases = {}
cli_args = []
if subcommand is not None:
cli_args.append(subcommand)
for f in fields(args_obj):
if (value := getattr(args_obj, f.name)) is None:
continue
# Determine CLI flag name
flag = f"--{special_cases.get(f.name, f.name)}"
# Handle booleans as flags
if f.type is bool or (getattr(f.type, '__origin__', None) is type(None) and isinstance(value, bool)):
if value:
cli_args.append(flag)
continue
# TODO: Handle list as positional arguments. This is hacky, maybe introspect the parser for nargs?
if isinstance(value, list):
cli_args.extend(value)
continue
# Otherwise, add as --flag value
cli_args.extend([flag, str(value)])
return cli_args
@pytest.mark.skipif(not HAS_HYPOTHESIS, reason="Hypothesis is not installed")
@given(
    st.builds(
        DefaultArgs,
        engine=st.sampled_from(get_args(SUPPORTED_ENGINES)),
        store=st.sampled_from(['/', '/tmp']),
        # Pinned to False, so args_to_cli_args emits neither flag; presumably
        # because the two options conflict on the real parser — TODO confirm.
        debug=st.just(False),
        quiet=st.just(False),
    )
)
def test_default_endpoint(default_args):
    """
    Round-trip generated DefaultArgs through the real top-level parser.

    Every dataclass field of DefaultArgs must appear as an attribute on the
    parsed namespace, keeping the dataclass and the parser in sync.
    """
    cli_args = args_to_cli_args(default_args, None, special_cases)
    parsed = parser.parse_args(cli_args)
    # fields() yields only real dataclass fields; __dataclass_fields__ would
    # also include ClassVar/InitVar pseudo-fields the parser never produces.
    for dc_field in fields(DefaultArgs):
        assert hasattr(parsed, dc_field.name), f"Missing attribute: {dc_field.name}"
@pytest.mark.skipif(not HAS_HYPOTHESIS, reason="Hypothesis is not installed")
@given(
    st.builds(
        ChatSubArgs,
        prefix=st.sampled_from(['> ', '🦙 > ', '🦭 > ', '🐋 > ']),
        url=st.sampled_from(['https://test.com', 'test.com']),
        temp=st.one_of(
            st.none(),
            # Collapses any zero (including -0.0, since -0.0 == 0) to 0.0;
            # presumably so the CLI never sees "-0.0" — TODO confirm.
            st.floats(min_value=0, allow_nan=False, allow_infinity=False).map(lambda v: 0.0 if v == 0 else v),
        ),
    )
)
def test_chat_endpoint(chat_args):
    """
    Round-trip generated ChatSubArgs through the parser's 'chat' subcommand.

    Every dataclass field of ChatSubArgs must appear as an attribute on the
    parsed namespace, keeping the dataclass and the chat parser in sync.
    """
    cli_args = args_to_cli_args(chat_args, 'chat', special_cases)
    parsed = parser.parse_args(cli_args)
    # fields() yields only real dataclass fields; __dataclass_fields__ would
    # also include ClassVar/InitVar pseudo-fields the parser never produces.
    for dc_field in fields(ChatSubArgs):
        assert hasattr(parsed, dc_field.name), f"Missing attribute: {dc_field.name}"