1
0
mirror of https://github.com/containers/ramalama.git synced 2026-02-05 06:46:39 +01:00
Files
ramalama/test/unit/test_transport_factory.py
2026-01-24 16:00:59 -06:00

190 lines
7.1 KiB
Python

from dataclasses import dataclass
from typing import Union
import pytest
import ramalama.transports.transport_factory as transport_factory_module
from ramalama.chat_providers.openai import OpenAIResponsesChatProvider
from ramalama.transports.api import APITransport
from ramalama.transports.huggingface import Huggingface
from ramalama.transports.modelscope import ModelScope
from ramalama.transports.oci import OCI
from ramalama.transports.ollama import Ollama
from ramalama.transports.rlcr import RamalamaContainerRegistry
from ramalama.transports.transport_factory import TransportFactory
from ramalama.transports.url import URL
@dataclass
class Input:
    """Bundle of parametrized inputs for one TransportFactory test case."""

    Model: str  # raw model reference string handed to TransportFactory
    Transport: str  # explicit transport name; "" means infer from Model
    Engine: str  # container engine forwarded via ARGS (e.g. "podman"), "" if unused
class ARGS:
    """Minimal stand-in for the parsed CLI arguments that TransportFactory reads.

    ``store``/``engine``/``container`` are class-level defaults; each instance
    may override ``engine`` via the constructor.
    """

    store = "/tmp/store"
    engine = ""
    container = True

    def __init__(self, engine: str = "") -> None:
        # Shadow the class-level default with the per-test engine value.
        self.engine = engine
# Base URL of a Hugging Face "blob" page for a GGUF model; used to build
# URL-transport test inputs below.
hf_granite_blob = "https://huggingface.co/ibm-granite/granite-3b-code-base-2k-GGUF/blob"
@pytest.mark.parametrize(
    "input,expected,error",
    [
        (Input("", "", ""), None, KeyError),
        (Input("openai://gpt-4o-mini", "", ""), APITransport, None),
        (Input("huggingface://granite-code", "", ""), Huggingface, None),
        (Input("hf://granite-code", "", ""), Huggingface, None),
        (Input("hf.co/granite-code", "", ""), Huggingface, None),
        (Input("modelscope://granite-code", "", ""), ModelScope, None),
        (Input("ms://granite-code", "", ""), ModelScope, None),
        (Input("ollama://granite-code", "", ""), Ollama, None),
        (Input("ollama.com/library/granite-code", "", ""), Ollama, None),
        (Input("oci://granite-code", "", "podman"), OCI, None),
        (Input("docker://granite-code", "", "podman"), OCI, None),
        (Input("rlcr://granite-code", "", "podman"), RamalamaContainerRegistry, None),
        # NOTE(review): this case appeared twice with byte-identical arguments;
        # the duplicate was dropped. It may originally have covered a
        # "resolve"-style URL variant -- confirm against upstream history.
        (
            Input(
                f"{hf_granite_blob}/main/granite-3b-code-base.Q4_K_M.gguf",
                "",
                "",
            ),
            URL,
            None,
        ),
        (Input("file:///tmp/models/granite-3b-code-base.Q4_K_M.gguf", "", ""), URL, None),
        (Input("granite-code", "huggingface", ""), Huggingface, None),
        (Input("granite-code", "ollama", ""), Ollama, None),
        (Input("granite-code", "oci", ""), OCI, ValueError),
    ],
)
def test_model_factory_create(input: Input, expected: type[Union[Huggingface, Ollama, URL, OCI]], error):
    """TransportFactory.create() returns an instance of the transport class
    matching the model reference's scheme/prefix (or the explicit transport),
    and raises ``error`` for invalid or unsupported references.
    """
    args = ARGS(input.Engine)
    if error is not None:
        # Invalid references must surface as an exception during create().
        with pytest.raises(error):
            TransportFactory(input.Model, args, input.Transport).create()
    else:
        model = TransportFactory(input.Model, args, input.Transport).create()
        assert isinstance(model, expected)
@pytest.mark.parametrize(
    "input,error",
    [
        (Input("", "", ""), KeyError),
        (Input("oci://granite-code", "", "podman"), None),
        (Input("docker://granite-code", "", "podman"), None),
        (Input("rlcr://granite-code", "", "podman"), None),
        (Input("file:///tmp/models/granite-3b-code-base.Q4_K_M.gguf", "", ""), ValueError),
        (Input("huggingface://granite-code", "", ""), ValueError),
        (Input("hf://granite-code", "", ""), ValueError),
        (Input("hf.co/granite-code", "", ""), None),
        (Input("modelscope://granite-code", "", ""), ValueError),
        (Input("ms://granite-code", "", ""), ValueError),
        (Input("ollama://granite-code", "", ""), ValueError),
        (Input("ollama.com/library/granite-code", "", ""), None),
        (Input("granite-code", "ollama", ""), None),
        (Input("granite-code", "", ""), KeyError),
    ],
)
def test_validate_oci_model_input(input: Input, error):
    """validate_oci_model_input() accepts OCI-compatible references and raises
    ``error`` for references tied to a non-OCI transport.
    """
    args = ARGS(input.Engine)
    if error is None:
        # Valid references must pass validation without raising.
        TransportFactory(input.Model, args, input.Transport).validate_oci_model_input()
        return
    # Construction stays inside the raises-context: for some invalid inputs
    # (e.g. the empty-model KeyError cases) the factory itself may raise.
    with pytest.raises(error):
        TransportFactory(input.Model, args, input.Transport).validate_oci_model_input()
@pytest.mark.parametrize(
    "input,expected",
    [
        (Input("openai://gpt-4o-mini", "", ""), "gpt-4o-mini"),
        (Input("huggingface://granite-code", "", ""), "granite-code"),
        (
            Input("huggingface://ibm-granite/granite-3b-code-base-2k-GGUF/granite-code", "", ""),
            "ibm-granite/granite-3b-code-base-2k-GGUF/granite-code",
        ),
        (Input("hf://granite-code", "", ""), "granite-code"),
        (Input("hf.co/granite-code", "", ""), "granite-code"),
        (Input("modelscope://granite-code", "", ""), "granite-code"),
        (
            Input("modelscope://ibm-granite/granite-3b-code-base-2k-GGUF/granite-code", "", ""),
            "ibm-granite/granite-3b-code-base-2k-GGUF/granite-code",
        ),
        (Input("ms://granite-code", "", ""), "granite-code"),
        (Input("ollama://granite-code", "", ""), "granite-code"),
        (Input("ollama.com/library/granite-code", "", ""), "granite-code"),
        (
            Input("ollama.com/library/ibm-granite/granite-3b-code-base-2k-GGUF/granite-code", "", ""),
            "ibm-granite/granite-3b-code-base-2k-GGUF/granite-code",
        ),
        (Input("oci://granite-code", "", "podman"), "granite-code"),
        (Input("docker://granite-code", "", "podman"), "granite-code"),
        (Input("rlcr://granite-code", "", "podman"), "granite-code"),
        # NOTE(review): this case appeared twice with byte-identical arguments;
        # the duplicate was dropped. It may originally have covered a
        # "resolve"-style URL variant -- confirm against upstream history.
        (
            Input(
                f"{hf_granite_blob}/main/granite-3b-code-base.Q4_K_M.gguf",
                "",
                "",
            ),
            "huggingface.co/ibm-granite/granite-3b-code-base-2k-GGUF/blob/main/granite-3b-code-base.Q4_K_M.gguf",
        ),
        (
            Input("file:///tmp/models/granite-3b-code-base.Q4_K_M.gguf", "", ""),
            "/tmp/models/granite-3b-code-base.Q4_K_M.gguf",
        ),
        (Input("granite-code", "huggingface", ""), "granite-code"),
        (Input("granite-code", "ollama", ""), "granite-code"),
        (Input("granite-code", "oci", ""), "granite-code"),
    ],
)
def test_prune_model_input(input: Input, expected: str):
    """prune_model_input() strips the transport scheme/prefix from a model
    reference, leaving only the transport-local model path.
    """
    args = ARGS(input.Engine)
    pruned_model_input = TransportFactory(input.Model, args, input.Transport).prune_model_input()
    assert pruned_model_input == expected
def test_transport_factory_passes_scheme_to_get_chat_provider(monkeypatch):
    """create() must forward the URI scheme ("openai") to get_chat_provider
    and attach the provider it returns to the resulting APITransport.
    """
    stub_provider = OpenAIResponsesChatProvider("https://api.openai.com/v1")
    seen_schemes: list[str] = []

    def record_scheme(scheme: str):
        # Capture the scheme the factory passes, then hand back our stub.
        seen_schemes.append(scheme)
        return stub_provider

    monkeypatch.setattr(transport_factory_module, "get_chat_provider", record_scheme)

    transport = TransportFactory("openai://gpt-4o-mini", ARGS()).create()

    assert seen_schemes == ["openai"]
    assert isinstance(transport, APITransport)
    assert transport.provider is stub_provider