
various typing and bug fixes

Signed-off-by: Ian Eaves <ian.k.eaves@gmail.com>
Ian Eaves
2025-07-31 10:27:25 -05:00
parent c4ed1419f3
commit 9ec66d5604
23 changed files with 154 additions and 115 deletions

View File

@@ -111,9 +111,8 @@ docs:
.PHONY: lint
lint:
ifneq (,$(wildcard /usr/bin/python3))
-/usr/bin/python3 -m compileall -q .
+/usr/bin/python3 -m compileall -q -x '\.venv' .
endif
! grep -ri --exclude-dir ".venv" --exclude-dir "*/.venv" "#\!/usr/bin/python3" .
flake8 $(PROJECT_DIR) $(PYTHON_SCRIPTS)
shellcheck *.sh */*.sh */*/*.sh
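The new -x '\.venv' argument is a regular expression of paths for compileall to skip. A minimal sketch of the same exclusion through the compileall module API (the Makefile uses the CLI form; this call is just the equivalent):

    # Byte-compile the tree quietly, skipping any path that matches .venv
    import compileall
    import re

    compileall.compile_dir(".", quiet=1, rx=re.compile(r"\.venv"))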

View File

@@ -96,7 +96,7 @@ class RamaLamaShell(cmd.Cmd):
operational_args = ChatOperationalArgs()
super().__init__()
-self.conversation_history = []
+self.conversation_history: list[dict] = []
self.args = args
self.request_in_process = False
self.prompt = args.prefix

View File

@@ -16,7 +16,7 @@ from ramalama.config import COLOR_OPTIONS, SUPPORTED_RUNTIMES
try:
import argcomplete
-suppressCompleter = argcomplete.completers.SuppressCompleter
+suppressCompleter: type[argcomplete.completers.SuppressCompleter] | None = argcomplete.completers.SuppressCompleter
except Exception:
suppressCompleter = None
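Annotating the name on its first assignment gives both branches of the try/except one declared type, so checkers accept the None fallback. A minimal sketch of the pattern, using a hypothetical optional dependency:

    from types import ModuleType

    try:
        import fastjson  # hypothetical optional dependency
        fastjson_mod: ModuleType | None = fastjson
    except ImportError:
        fastjson_mod = None  # same declared type, so this branch type-checks

    if fastjson_mod is not None:
        pass  # take the accelerated path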
@@ -44,7 +44,6 @@ GENERATE_OPTIONS = ["quadlet", "kube", "quadlet/kube"]
class ParsedGenerateInput:
def __init__(self, gen_type: str, output_dir: str):
self.gen_type = gen_type
self.output_dir = output_dir
@@ -1226,7 +1225,6 @@ def inspect_cli(args):
def main():
def eprint(e, exit_code):
perror("Error: " + str(e).strip("'\""))
sys.exit(exit_code)

View File

@@ -13,8 +13,9 @@ import shutil
import string
import subprocess
import sys
+from collections.abc import Callable, Iterable
from functools import lru_cache
-from typing import TYPE_CHECKING, Callable, List, Literal, Protocol, cast, get_args
+from typing import TYPE_CHECKING, Literal, Protocol, TypeAlias, TypedDict, cast, get_args
import ramalama.amdkfd as amdkfd
from ramalama.logger import logger
@@ -230,7 +231,15 @@ def engine_version(engine: SUPPORTED_ENGINES) -> str:
return run_cmd(cmd_args).stdout.decode("utf-8").strip()
-def load_cdi_yaml(stream) -> dict:
+class CDI_DEVICE(TypedDict):
+name: str
+class CDI_RETURN_TYPE(TypedDict):
+devices: list[CDI_DEVICE]
+def load_cdi_yaml(stream: Iterable[str]) -> CDI_RETURN_TYPE:
# Returns a dict containing just the "devices" key, whose value is
# a list of dicts, each mapping the key "name" to a device name.
# For example: {'devices': [{'name': 'all'}]}
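With the TypedDict declarations, type checkers know the exact shape of the parsed CDI data instead of treating it as an opaque dict. A small sketch using the names from the diff:

    from typing import TypedDict

    class CDI_DEVICE(TypedDict):
        name: str

    class CDI_RETURN_TYPE(TypedDict):
        devices: list[CDI_DEVICE]

    data: CDI_RETURN_TYPE = {"devices": [{"name": "all"}]}
    print(data["devices"][0]["name"])  # -> all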
@@ -238,7 +247,7 @@ def load_cdi_yaml(stream) -> dict:
# under "devices" and the value of the "name" key being on the
# same line following a colon.
data = {"devices": []}
data: CDI_RETURN_TYPE = {"devices": []}
for line in stream:
if ':' in line:
key, value = line.split(':', 1)
@@ -247,7 +256,7 @@ def load_cdi_yaml(stream) -> dict:
return data
-def load_cdi_config(spec_dirs: List[str]) -> dict | None:
+def load_cdi_config(spec_dirs: list[str]) -> CDI_RETURN_TYPE | None:
# Loads the first YAML or JSON CDI configuration file found in the
# given directories."""
@@ -275,7 +284,7 @@ def load_cdi_config(spec_dirs: List[str]) -> dict | None:
return None
-def find_in_cdi(devices: List[str]) -> tuple[List[str], List[str]]:
+def find_in_cdi(devices: list[str]) -> tuple[list[str], list[str]]:
# Attempts to find a CDI configuration for each device in devices
# and returns a list of configured devices and a list of
# unconfigured devices.
@@ -327,11 +336,12 @@ def check_nvidia() -> Literal["cuda"] | None:
return None
smi_lines = result.stdout.splitlines()
-parsed_lines = [[item.strip() for item in line.split(',')] for line in smi_lines if line]
+parsed_lines: list[list[str]] = [[item.strip() for item in line.split(',')] for line in smi_lines if line]
if not parsed_lines:
return None
-indices, uuids = zip(*parsed_lines) if parsed_lines else (tuple(), tuple())
+indices, uuids = map(list, zip(*parsed_lines))
# Get the list of devices specified by CUDA_VISIBLE_DEVICES, if any
cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "")
visible_devices = cuda_visible_devices.split(',') if cuda_visible_devices else []
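With the empty case already returning None above, the old "if parsed_lines else (tuple(), tuple())" fallback was dead code; map(list, ...) also yields real lists rather than the tuples zip produces. A tiny sketch of the transposition:

    # zip(*rows) pairs up the i-th element of every row (a transpose);
    # map(list, ...) converts the resulting tuples into lists.
    rows = [["0", "GPU-aaa"], ["1", "GPU-bbb"]]
    indices, uuids = map(list, zip(*rows))
    assert indices == ["0", "1"]
    assert uuids == ["GPU-aaa", "GPU-bbb"]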
@@ -342,14 +352,14 @@ def check_nvidia() -> Literal["cuda"] | None:
configured, unconfigured = find_in_cdi(visible_devices + ["all"])
if unconfigured and "all" not in configured:
if unconfigured and not (configured_has_all := "all" in configured):
perror(f"No CDI configuration found for {','.join(unconfigured)}")
perror("You can use the \"nvidia-ctk cdi generate\" command from the ")
perror("nvidia-container-toolkit to generate a CDI configuration.")
perror("See ramalama-cuda(7).")
return None
elif configured:
if "all" in configured:
if configured_has_all:
configured.remove("all")
if not configured:
configured = indices
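One subtlety of the walrus here: inside "a and (x := b)", the assignment only runs when a is truthy, because "and" short-circuits, so any branch reached with unconfigured empty must not read the name. A minimal sketch of the semantics:

    configured = ["all"]
    unconfigured: list[str] = []
    if unconfigured and not (configured_has_all := "all" in configured):
        pass
    print("configured_has_all" in dir())  # False: the binding never ran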
@@ -442,7 +452,7 @@ def check_mthreads() -> Literal["musa"] | None:
return None
-AccelType = Literal["asahi", "cuda", "cann", "hip", "intel", "musa"]
+AccelType: TypeAlias = Literal["asahi", "cuda", "cann", "hip", "intel", "musa"]
def get_accel() -> AccelType | Literal["none"]:
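The explicit TypeAlias marker tells checkers that the assignment defines a type name rather than an ordinary module-level value, which is ambiguous for a bare "Name = Literal[...]" assignment. A short sketch:

    from typing import Literal, TypeAlias

    AccelType: TypeAlias = Literal["asahi", "cuda", "cann", "hip", "intel", "musa"]

    def get_accel() -> AccelType | Literal["none"]:
        return "none"  # stand-in body; the real function probes hardware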
@@ -474,7 +484,7 @@ def set_gpu_type_env_vars():
get_accel()
-GPUEnvVar = Literal[
+GPUEnvVar: TypeAlias = Literal[
"ASAHI_VISIBLE_DEVICES",
"ASCEND_VISIBLE_DEVICES",
"CUDA_VISIBLE_DEVICES",
@@ -486,10 +496,10 @@ GPUEnvVar = Literal[
def get_gpu_type_env_vars() -> dict[GPUEnvVar, str]:
-return {k: os.environ[k] for k in get_args(GPUEnvVar) if k in os.environ}
+return {k: v for k in get_args(GPUEnvVar) if (v := os.environ.get(k))}
-AccelEnvVar = Literal[
+AccelEnvVar: TypeAlias = Literal[
"CUDA_LAUNCH_BLOCKING",
"HSA_VISIBLE_DEVICES",
"HSA_OVERRIDE_GFX_VERSION",
@@ -498,7 +508,7 @@ AccelEnvVar = Literal[
def get_accel_env_vars() -> dict[GPUEnvVar | AccelEnvVar, str]:
gpu_env_vars: dict[GPUEnvVar, str] = get_gpu_type_env_vars()
-accel_env_vars: dict[AccelEnvVar, str] = {k: os.environ[k] for k in get_args(AccelEnvVar) if k in os.environ}
+accel_env_vars: dict[AccelEnvVar, str] = {k: v for k in get_args(AccelEnvVar) if (v := os.environ.get(k))}
return gpu_env_vars | accel_env_vars
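The comprehension change is a small behavior fix as well as a typing one: os.environ.get(k) returns None for unset keys, and the walrus filter also drops keys that are set but empty, which the old "k in os.environ" test kept. A sketch:

    import os
    from typing import Literal, get_args

    Var = Literal["CUDA_VISIBLE_DEVICES", "HIP_VISIBLE_DEVICES"]  # trimmed

    os.environ["CUDA_VISIBLE_DEVICES"] = ""  # present but empty
    env = {k: v for k in get_args(Var) if (v := os.environ.get(k))}
    assert "CUDA_VISIBLE_DEVICES" not in env  # filtered out now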
@@ -599,7 +609,9 @@ class AccelImageArgsOtherRuntimeRAG(Protocol):
quiet: bool
-AccelImageArgs = None | AccelImageArgsVLLMRuntime | AccelImageArgsOtherRuntime | AccelImageArgsOtherRuntimeRAG
+AccelImageArgs: TypeAlias = (
+None | AccelImageArgsVLLMRuntime | AccelImageArgsOtherRuntime | AccelImageArgsOtherRuntimeRAG
+)
def accel_image(config: Config) -> str:
@@ -627,7 +639,7 @@ def accel_image(config: Config) -> str:
vers = minor_release()
should_pull = config.pull in ["always", "missing"] and not config.dryrun
-if attempt_to_use_versioned(config.engine, image, vers, True, should_pull):
+if config.engine and attempt_to_use_versioned(config.engine, image, vers, True, should_pull):
return f"{image}:{vers}"
return f"{image}:latest"

View File

@@ -3,19 +3,19 @@ import os
import sys
from dataclasses import dataclass, field
from pathlib import Path
-from typing import Any, Literal, Mapping
+from typing import Any, Literal, Mapping, TypeAlias
from ramalama.common import available
from ramalama.layered_config import LayeredMixin, deep_merge
from ramalama.toml_parser import TOMLParser
-PathStr = str
+PathStr: TypeAlias = str
DEFAULT_PORT_RANGE: tuple[int, int] = (8080, 8090)
DEFAULT_PORT: int = DEFAULT_PORT_RANGE[0]
DEFAULT_IMAGE = "quay.io/ramalama/ramalama"
SUPPORTED_ENGINES = Literal["podman", "docker"] | PathStr
SUPPORTED_RUNTIMES = Literal["llama.cpp", "vllm", "mlx"]
COLOR_OPTIONS = Literal["auto", "always", "never"]
DEFAULT_IMAGE: str = "quay.io/ramalama/ramalama"
SUPPORTED_ENGINES: TypeAlias = Literal["podman", "docker"] | PathStr
SUPPORTED_RUNTIMES: TypeAlias = Literal["llama.cpp", "vllm", "mlx"]
COLOR_OPTIONS: TypeAlias = Literal["auto", "always", "never"]
def get_default_engine() -> SUPPORTED_ENGINES | None:
@@ -157,7 +157,7 @@ def load_env_config(env: Mapping[str, str] | None = None) -> dict[str, Any]:
if env is None:
env = os.environ
-config = {}
+config: dict[str, Any] = {}
for k, v in env.items():
if not k.startswith("RAMALAMA"):
continue

View File

@@ -23,7 +23,7 @@ class BaseFileManager(ABC):
return loader
@abstractmethod
-def load(self):
+def load(self, *args, **kwargs):
pass
@classmethod
@@ -121,12 +121,11 @@ class OpanAIChatAPIMessageBuilder:
if unsupported_files:
unsupported_files_warning(unsupported_files, list(self.supported_extensions()))
-messages = []
+messages: list[dict] = []
if text_files:
messages.append({"role": "system", "content": self.text_manager.load(text_files)})
if image_files:
message = {"role": "system", "content": []}
for content in self.image_manager.load(image_files):
message['content'].append({"type": "image_url", "image_url": {"url": content}})
content = [{"type": "image_url", "image_url": {"url": c}} for c in self.image_manager.load(image_files)]
message = {"role": "system", "content": content}
messages.append(message)
return messages

View File

@@ -64,7 +64,7 @@ class HFStyleRepository(ABC):
self.name = name
self.organization = organization
self.tag = tag
-self.headers = {}
+self.headers: dict = {}
self.blob_url = None
self.model_filename = None
self.model_hash = None

View File

@@ -5,6 +5,7 @@ import shlex
import socket
import sys
import time
+from abc import ABC, abstractmethod
from typing import Optional
import ramalama.chat as chat
@@ -54,7 +55,6 @@ $(error)s"""
class NoRefFileFound(Exception):
def __init__(self, model: str, *args):
super().__init__(*args)
@@ -74,7 +74,10 @@ def trim_model_name(model):
return model
-class ModelBase:
+class ModelBase(ABC):
+model: str
+type: str
def __not_implemented_error(self, param):
return NotImplementedError(f"ramalama {param} for '{type(self).__name__}' not implemented")
@@ -90,24 +93,31 @@ class ModelBase:
def push(self, source_model, args):
raise self.__not_implemented_error("push")
+@abstractmethod
def remove(self, args):
raise self.__not_implemented_error("rm")
+@abstractmethod
def bench(self, args):
raise self.__not_implemented_error("bench")
+@abstractmethod
def run(self, args):
raise self.__not_implemented_error("run")
+@abstractmethod
def perplexity(self, args):
raise self.__not_implemented_error("perplexity")
+@abstractmethod
def serve(self, args):
raise self.__not_implemented_error("serve")
+@abstractmethod
def exists(self) -> bool:
raise self.__not_implemented_error("exists")
+@abstractmethod
def inspect(self, args):
raise self.__not_implemented_error("inspect")
@@ -115,15 +125,14 @@ class ModelBase:
class Model(ModelBase):
"""Model super class"""
model = ""
type = "Model"
type: str = "Model"
def __init__(self, model, model_store_path):
def __init__(self, model: str, model_store_path: str):
self.model = model
-split = self.model.rsplit("/", 1)
-self.directory = split[0] if len(split) > 1 else ""
-self.filename = split[1] if len(split) > 1 else split[0]
+split: list[str] = self.model.rsplit("/", 1)
+self.directory: str = split[0] if len(split) > 1 else ""
+self.filename: str = split[1] if len(split) > 1 else split[0]
self._model_name: str
self._model_tag: str
@@ -685,7 +694,6 @@ class Model(ModelBase):
return exec_args
def generate_container_config(self, args, exec_args):
# Get the blob paths (src) and mounted paths (dest)
model_src_path = self._get_entry_model_path(False, False, args.dryrun)
chat_template_src_path = self._get_chat_template_path(False, False, args.dryrun)
@@ -776,7 +784,7 @@ class Model(ModelBase):
kube = Kube(self.model_name, model_paths, chat_template_paths, mmproj_paths, args, exec_args)
kube.generate().write(output_dir)
-def inspect(self, args):
+def inspect(self, args) -> None:
self.ensure_model_exists(args)
model_name = self.filename

View File

@@ -1,8 +1,9 @@
import copy
-from typing import Callable, Tuple, Union
+from collections.abc import Callable
+from typing import TypeAlias
from urllib.parse import urlparse
-from ramalama.arg_types import StoreArgs
+from ramalama.arg_types import StoreArgType
from ramalama.common import rm_until_substring
from ramalama.config import CONFIG
from ramalama.huggingface import Huggingface
@@ -12,12 +13,14 @@ from ramalama.oci import OCI
from ramalama.ollama import Ollama
from ramalama.url import URL
+CLASS_MODEL_TYPES: TypeAlias = Huggingface | Ollama | OCI | URL | ModelScope
class ModelFactory:
def __init__(
self,
model: str,
-args: StoreArgs,
+args: StoreArgType,
transport: str = "ollama",
ignore_stderr: bool = False,
):
@@ -28,20 +31,20 @@ class ModelFactory:
self.ignore_stderr = ignore_stderr
self.container = args.container
-self.model_cls: type[Union[Huggingface, ModelScope, Ollama, OCI, URL]]
-self.create: Callable[[], Union[Huggingface, ModelScope, Ollama, OCI, URL]]
-self.model_cls, self.create = self.detect_model_model_type()
+model_cls, _create = self.detect_model_model_type()
+self.model_cls = model_cls
+self._create = _create
self.pruned_model = self.prune_model_input()
self.draft_model = None
if getattr(args, 'model_draft', None):
dm_args = copy.deepcopy(args)
dm_args.model_draft = None
self.draft_model = ModelFactory(args.model_draft, dm_args, ignore_stderr=True).create()
-def detect_model_model_type(
-self,
-) -> Tuple[type[Union[Huggingface, Ollama, OCI, URL]], Callable[[], Union[Huggingface, Ollama, OCI, URL]]]:
+def detect_model_model_type(self) -> tuple[type[CLASS_MODEL_TYPES], Callable[[], CLASS_MODEL_TYPES]]:
for prefix in ["huggingface://", "hf://", "hf.co/"]:
if self.model.startswith(prefix):
return Huggingface, self.create_huggingface
@@ -89,6 +92,9 @@ class ModelFactory:
if self.model.startswith(t + "://"):
raise ValueError(f"{self.model} invalid: Only OCI Model types supported")
+def create(self) -> CLASS_MODEL_TYPES:
+return self._create()
def create_huggingface(self) -> Huggingface:
model = Huggingface(self.pruned_model, self.store_path)
model.draft_model = self.draft_model
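The rename to _create avoids a name collision: an instance attribute assigned in __init__ shadows a method of the same name, so the old "self.create = ..." could not coexist with the new create() method. A tiny sketch of the pitfall:

    class Factory:
        def __init__(self):
            self.create = lambda: "from attribute"  # shadows the method

        def create(self):
            return "from method"

    print(Factory().create())  # -> from attribute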
@@ -108,7 +114,11 @@ class ModelFactory:
if not self.container:
raise ValueError("OCI containers cannot be used with the --nocontainer option.")
+if self.engine is None:
+raise ValueError("Constructing an OCI model factory requires an engine value")
self.validate_oci_model_input()
model = OCI(self.pruned_model, self.store_path, self.engine, self.ignore_stderr)
model.draft_model = self.draft_model
return model
@@ -119,13 +129,13 @@ class ModelFactory:
return model
-def New(name, args, transport: str = None) -> Union[Huggingface | ModelScope | Ollama | OCI | URL]:
+def New(name, args, transport: str | None = None) -> CLASS_MODEL_TYPES:
if transport is None:
transport = CONFIG.transport
return ModelFactory(name, args, transport=transport).create()
-def Serve(name, args):
+def Serve(name, args) -> None:
model = New(name, args)
try:
model.serve(args)

View File

@@ -53,7 +53,7 @@ class GGUFModelInfo(ModelInfoBase):
i = 0
for tensor in self.Tensors:
ret = ret + adjust_new_line(
f" {i}: {tensor.name, tensor.type.name, tensor.n_dimensions, tensor.offset}"
f" {i}: {tensor.name, tensor.type, tensor.n_dimensions, tensor.offset}"
)
i = i + 1

View File

@@ -1,7 +1,7 @@
import io
import struct
from enum import IntEnum
-from typing import Any, Dict
+from typing import Any, Dict, cast
from ramalama.endian import GGUFEndian
from ramalama.logger import logger
@@ -116,7 +116,7 @@ class GGUFInfoParser:
model: io.BufferedReader, model_endianness: GGUFEndian = GGUFEndian.LITTLE, length: int = -1
) -> str:
if length == -1:
-length = GGUFInfoParser.read_number(model, GGUFValueType.UINT64, model_endianness)
+length = cast(int, GGUFInfoParser.read_number(model, GGUFValueType.UINT64, model_endianness))
raw = model.read(length)
if len(raw) < length:
@@ -144,12 +144,14 @@ class GGUFInfoParser:
@staticmethod
def read_value_type(model: io.BufferedReader, model_endianness: GGUFEndian) -> GGUFValueType:
-value_type = GGUFInfoParser.read_number(model, GGUFValueType.UINT32, model_endianness)
+value_type = cast(int, GGUFInfoParser.read_number(model, GGUFValueType.UINT32, model_endianness))
return GGUFValueType(value_type)
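read_number evidently returns a union of possible value types, so cast(int, ...) is used wherever the result feeds a length or count. cast performs no conversion or checking at runtime; it only narrows the type for the checker. A sketch:

    from typing import cast

    def read_number() -> int | float:  # stand-in for the parser's method
        return 8

    length = cast(int, read_number())  # the checker now treats length as int
    raw = b"\x00" * length             # valid: a bytes repeat count must be int
    assert len(raw) == 8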
@staticmethod
-def read_value(model: io.BufferedReader, value_type: GGUFValueType, model_endianness: GGUFEndian) -> Any:
-value = None
+def read_value(
+model: io.BufferedReader, value_type: GGUFValueType, model_endianness: GGUFEndian
+) -> str | int | float | bool | list:
+value: Any
if value_type in GGUF_NUMBER_FORMATS:
value = GGUFInfoParser.read_number(model, value_type, model_endianness)
elif value_type == GGUFValueType.BOOL:
@@ -158,13 +160,13 @@ class GGUFInfoParser:
value = GGUFInfoParser.read_string(model, model_endianness)
elif value_type == GGUFValueType.ARRAY:
array_type = GGUFInfoParser.read_value_type(model, model_endianness)
-array_length = GGUFInfoParser.read_number(model, GGUFValueType.UINT64, model_endianness)
+array_length = cast(int, GGUFInfoParser.read_number(model, GGUFValueType.UINT64, model_endianness))
value = [GGUFInfoParser.read_value(model, array_type, model_endianness) for _ in range(array_length)]
-if value is not None:
-return value
else:
raise ParseError(f"Unknown type '{value_type}'")
+return value
@staticmethod
def get_model_endianness(model_path: str) -> GGUFEndian:
# Pin model endianness to Little Endian by default.
@@ -176,7 +178,7 @@ class GGUFInfoParser:
if magic_number != GGUFModelInfo.MAGIC_NUMBER:
raise ParseError(f"Invalid GGUF magic number '{magic_number}'")
-gguf_version = GGUFInfoParser.read_number(model, GGUFValueType.UINT32, model_endianness)
+gguf_version = cast(int, GGUFInfoParser.read_number(model, GGUFValueType.UINT32, model_endianness))
if gguf_version & 0xFFFF == 0x0000:
model_endianness = GGUFEndian.BIG
@@ -191,10 +193,10 @@ class GGUFInfoParser:
if magic_number != GGUFModelInfo.MAGIC_NUMBER:
raise ParseError(f"Invalid GGUF magic number '{magic_number}'")
-gguf_version = GGUFInfoParser.read_number(model, GGUFValueType.UINT32, model_endianness)
+gguf_version = cast(int, GGUFInfoParser.read_number(model, GGUFValueType.UINT32, model_endianness))
-tensor_count = GGUFInfoParser.read_number(model, GGUFValueType.UINT64, model_endianness)
-metadata_kv_count = GGUFInfoParser.read_number(model, GGUFValueType.UINT64, model_endianness)
+tensor_count = cast(int, GGUFInfoParser.read_number(model, GGUFValueType.UINT64, model_endianness))
+metadata_kv_count = cast(int, GGUFInfoParser.read_number(model, GGUFValueType.UINT64, model_endianness))
metadata = {}
for _ in range(metadata_kv_count):
@@ -205,13 +207,17 @@ class GGUFInfoParser:
tensors: list[Tensor] = []
for _ in range(tensor_count):
name = GGUFInfoParser.read_string(model, model_endianness)
-n_dimensions = GGUFInfoParser.read_number(model, GGUFValueType.UINT32, model_endianness)
+n_dimensions = cast(int, GGUFInfoParser.read_number(model, GGUFValueType.UINT32, model_endianness))
dimensions: list[int] = []
for _ in range(n_dimensions):
-dimensions.append(GGUFInfoParser.read_number(model, GGUFValueType.UINT64, model_endianness))
-tensor_type = GGML_TYPE(GGUFInfoParser.read_number(model, GGUFValueType.UINT32, model_endianness))
-offset = GGUFInfoParser.read_number(model, GGUFValueType.UINT64, model_endianness)
-tensors.append(Tensor(name, n_dimensions, dimensions, tensor_type, offset))
+dim = cast(int, GGUFInfoParser.read_number(model, GGUFValueType.UINT64, model_endianness))
+dimensions.append(dim)
+tensor_type = GGML_TYPE(
+cast(int, GGUFInfoParser.read_number(model, GGUFValueType.UINT32, model_endianness))
+)
+offset = cast(int, GGUFInfoParser.read_number(model, GGUFValueType.UINT64, model_endianness))
+tensors.append(Tensor(name, n_dimensions, dimensions, tensor_type.name, offset))
return GGUFModelInfo(
model_name, model_registry, model_path, gguf_version, metadata, tensors, model_endianness

View File

@@ -71,7 +71,7 @@ class GlobalModelStore:
last_modified = os.path.getmtime(snapshot_file_path)
file_size = os.path.getsize(snapshot_file_path)
collected_files.append(
-ModelFile(snapshot_file, last_modified, file_size, is_partially_downloaded)
+ModelFile(snapshot_file.name, last_modified, file_size, is_partially_downloaded)
)
models[model_name] = collected_files

View File

@@ -7,7 +7,7 @@ import re
import sys
from dataclasses import dataclass, field
from enum import Enum
-from typing import Dict, Optional
+from typing import Dict, Optional, cast
class NodeType(Enum):
@@ -298,7 +298,6 @@ def parse_go_template(content: str) -> list[Node]:
artificial=False,
)
root_nodes.append(content_node)
return root_nodes
@@ -353,7 +352,6 @@ def translate_continue_nodes(root_nodes: list[Node]) -> list[Node]:
parent_node.children = parent_node.children[:start_index] + [if_node, end_node]
find_continue_nodes(root_nodes)
skip_variable = "$should_continue"
for continue_node in continue_nodes:
# find start of loop to initialize continue skip variable
@@ -376,6 +374,7 @@ def translate_continue_nodes(root_nodes: list[Node]) -> list[Node]:
children=[],
artificial=True,
)
+for_node.next = cast(Node, for_node.next)
for_node.next.prev = initial_set_node
for_node.next = initial_set_node
for_node.children = [initial_set_node] + for_node.children
@@ -553,13 +552,13 @@ def go_to_jinja(content: str) -> str:
if m is None:
return ""
-if node.open_scope_node.type in [NodeType.IF, NodeType.ELIF, NodeType.ELSE]:
+if node.open_scope_node.type in [NodeType.IF, NodeType.ELIF, NodeType.ELSE]:  # type: ignore
return (
node.content[: m.start(1)].replace(GO_SYMBOL_OPEN_BRACKETS, JINJA_SYMBOL_OPEN_BRACKETS)
+ "endif"
+ node.content[m.end(1) :].replace(GO_SYMBOL_CLOSE_BRACKETS, JINJA_SYMBOL_CLOSE_BRACKETS)
)
-elif node.open_scope_node.type == NodeType.RANGE:
+elif node.open_scope_node.type == NodeType.RANGE:  # type: ignore
loop_vars.pop()
if loop_index_vars:
loop_index_vars.pop()

View File

@@ -15,7 +15,7 @@ class RefFile:
CHAT_TEMPLATE_SUFFIX = "chat"
MMPROJ_SUFFIX = "mmproj"
-def __init__(self):
+def __init__(self) -> None:
self.hash: str = ""
self.filenames: list[str] = []
self.model_name: str = ""
@@ -133,7 +133,6 @@ def migrate_reffile_to_refjsonfile(ref_file_path: str, snapshot_directory: str)
class StoreFileType(StrEnum):
GGUF_MODEL = "gguf"
MMPROJ = "mmproj"
CHAT_TEMPLATE = "chat_template"
@@ -152,7 +151,6 @@ class StoreFileType(StrEnum):
@dataclass
class StoreFile:
hash: str
name: str
type: StoreFileType
@@ -160,7 +158,6 @@ class StoreFile:
@dataclass
class RefJSONFile:
hash: str
path: str
files: list[StoreFile]

View File

@@ -1,6 +1,6 @@
import os
from enum import IntEnum
-from typing import Dict
+from typing import Dict, Sequence
from ramalama.common import generate_sha256
from ramalama.http_client import download_file
@@ -60,7 +60,7 @@ class LocalSnapshotFile(SnapshotFile):
):
super().__init__(
"",
"",
{},
generate_sha256(content),
name,
type,
@@ -77,7 +77,7 @@ class LocalSnapshotFile(SnapshotFile):
return os.path.relpath(blob_file_path, start=snapshot_dir)
-def validate_snapshot_files(snapshot_files: list[SnapshotFile]):
+def validate_snapshot_files(snapshot_files: Sequence[SnapshotFile]):
chat_template_files = []
mmproj_files = []
for file in snapshot_files:

View File

@@ -4,7 +4,7 @@ import urllib.error
from collections import Counter
from http import HTTPStatus
from pathlib import Path
-from typing import Optional, Tuple
+from typing import Optional, Sequence, Tuple
import ramalama.model_store.go2jinja as go2jinja
from ramalama.common import perror, sanitize_filename, verify_checksum
@@ -97,9 +97,12 @@ class ModelStore:
return RefJSONFile.from_path(ref_file_path)
def update_ref_file(
-self, model_tag: str, snapshot_hash: str = "", snapshot_files: list[SnapshotFile] = []
+self, model_tag: str, snapshot_hash: str = "", snapshot_files: Optional[list[SnapshotFile]] = None
) -> Optional[RefJSONFile]:
-ref_file: RefJSONFile = self.get_ref_file(model_tag)
+if snapshot_files is None:
+snapshot_files = []
+ref_file: RefJSONFile | None = self.get_ref_file(model_tag)
if ref_file is None:
return None
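Replacing the list default with Optional[...] = None avoids the classic mutable-default pitfall: a default list is created once at function definition time and shared across calls. A minimal sketch:

    def append_bad(item, items=[]):
        items.append(item)
        return items

    assert append_bad(1) == [1]
    assert append_bad(2) == [1, 2]  # state leaked from the first call

    def append_good(item, items=None):
        items = [] if items is None else items
        items.append(item)
        return items

    assert append_good(1) == [1]
    assert append_good(2) == [2]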
@@ -154,9 +157,9 @@ class ModelStore:
)
def get_cached_files(self, model_tag: str) -> Tuple[str, list[str], bool]:
-cached_files = []
+cached_files: list[str] = []
-ref_file: RefJSONFile = self.get_ref_file(model_tag)
+ref_file: RefJSONFile | None = self.get_ref_file(model_tag)
if ref_file is None:
return ("", cached_files, False)
@@ -182,8 +185,10 @@ class ModelStore:
snapshot_directory = self.get_snapshot_directory(snapshot_hash)
os.makedirs(snapshot_directory, exist_ok=True)
-def _download_snapshot_files(self, model_tag: str, snapshot_hash: str, snapshot_files: list[SnapshotFile]):
-ref_file = self.get_ref_file(model_tag)
+def _download_snapshot_files(self, model_tag: str, snapshot_hash: str, snapshot_files: Sequence[SnapshotFile]):
+ref_file: None | RefJSONFile = self.get_ref_file(model_tag)
if ref_file is None:
raise ValueError("Cannot download snapshots without a valid ref file.")
for file in snapshot_files:
dest_path = self.get_blob_file_path(file.hash)
@@ -223,8 +228,8 @@ class ModelStore:
if file.type == SnapshotFileType.ChatTemplate:
chat_template_file_path = self.get_blob_file_path(file.hash)
chat_template = ""
with open(chat_template_file_path, "r") as file:
chat_template = file.read()
with open(chat_template_file_path, "r") as template_file:
chat_template = template_file.read()
if not go2jinja.is_go_template(chat_template):
return
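The rename to template_file fixes genuine shadowing: the with-statement was rebinding the enclosing for-loop's "file" variable to the open handle. A small sketch of the hazard:

    import io

    for file in ["a.txt", "b.txt"]:
        with io.StringIO() as file:  # rebinds "file"; the loop item is gone
            pass
        print(type(file).__name__)  # StringIO, not str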
@@ -321,7 +326,7 @@ class ModelStore:
self.remove_snapshot(model_tag)
raise ex
-def update_snapshot(self, model_tag: str, snapshot_hash: str, new_snapshot_files: list[SnapshotFile]) -> bool:
+def update_snapshot(self, model_tag: str, snapshot_hash: str, new_snapshot_files: Sequence[SnapshotFile]) -> bool:
validate_snapshot_files(new_snapshot_files)
snapshot_hash = sanitize_filename(snapshot_hash)
@@ -363,7 +368,7 @@ class ModelStore:
for entry in os.listdir(self.refs_directory)
if os.path.isfile(os.path.join(self.refs_directory, entry))
]
-refs = [self.get_ref_file(tag) for tag in model_tags]
+refs = [ref for tag in model_tags if (ref := self.get_ref_file(tag))]
blob_refcounts = Counter(file.name for ref in refs for file in ref.files)

View File

@@ -130,10 +130,11 @@ def list_models(args: EngineArgType):
class OCI(Model):
-def __init__(self, model, model_store_path, conman, ignore_stderr=False):
+type = "OCI"
+def __init__(self, model: str, model_store_path: str, conman: str, ignore_stderr: bool = False):
super().__init__(model, model_store_path)
-self.type = "OCI"
if not conman:
raise ValueError("RamaLama OCI Images requires a container engine")

View File

@@ -139,12 +139,12 @@ class OllamaRepository:
class Ollama(Model):
-def __init__(self, model, model_store_path):
+def __init__(self, model, model_store_path) -> None:
super().__init__(model, model_store_path)
self.type = "Ollama"
-def extract_model_identifiers(self):
+def extract_model_identifiers(self) -> tuple[str, str, str]:
model_name, model_tag, model_organization = super().extract_model_identifiers()
# use the ollama default namespace if no model organization has been identified
@@ -152,11 +152,11 @@ class Ollama(Model):
model_organization = "library"
return model_name, model_tag, model_organization
-def resolve_model(self):
+def resolve_model(self) -> str:
name, tag, organization = self.extract_model_identifiers()
return f"ollama://{organization}/{name}:{tag}"
-def pull(self, args):
+def pull(self, args) -> None:
name, tag, organization = self.extract_model_identifiers()
_, cached_files, all = self.model_store.get_cached_files(tag)
if all:

View File

@@ -12,17 +12,17 @@ INPUT_DIR = "/docs"
class Rag:
model = ""
target = ""
urls = []
model: str = ""
target: str = ""
urls: list[str] = []
def __init__(self, target):
def __init__(self, target: str):
if not target.islower():
raise ValueError(f"invalid reference format: repository name '{target}' must be lowercase")
self.target = target
set_accel_env_vars()
-def build(self, source, target, args):
+def build(self, source: str, target: str, args):
perror(f"\nBuilding {target} ...")
contextdir = os.path.dirname(source)
src = os.path.basename(source)
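One caveat the new annotations do not change: a list assigned at class level, like urls above, is a single object shared by every instance that never rebinds it. A minimal sketch of that sharing:

    class Rag:
        urls: list[str] = []

    a, b = Rag(), Rag()
    a.urls.append("https://example.com/doc")  # hypothetical URL
    print(b.urls)  # ['https://example.com/doc'] - the same shared list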
@@ -60,7 +60,7 @@ COPY {src} /vector.db
)
return imageid
-def _handle_paths(self, path):
+def _handle_paths(self, path: str):
"""Adds a volume mount if path exists, otherwise add URL."""
parsed = urlparse(path)
if parsed.scheme in ["file", ""] and parsed.netloc == "":
@@ -135,7 +135,7 @@ COPY {src} /vector.db
shutil.rmtree(ragdb.name, ignore_errors=True)
-def rag_image(image) -> str:
+def rag_image(image: str) -> str:
imagespec = image.split(":")
rag_image = f"{imagespec[0]}-rag"
if len(imagespec) > 1:

View File

@@ -104,7 +104,7 @@ class URL(Model):
return files
-def pull(self, _):
+def pull(self, _) -> None:
name, tag, _ = self.extract_model_identifiers()
_, _, all_files = self.model_store.get_cached_files(tag)
if all_files:

View File

@@ -313,6 +313,7 @@ class TestOpanAIChatAPIMessageBuilder:
assert isinstance(messages[1]["content"], list)
assert len(messages[1]["content"]) == 1
+@pytest.mark.filterwarnings("ignore:.*Unsupported file types detected!.*")
def test_builder_load_no_supported_files(self):
"""Test loading directory with no supported files."""
with tempfile.TemporaryDirectory() as tmp_dir:

View File

@@ -70,6 +70,7 @@ class TestFileUploadChatIntegration:
assert "readme.md" in system_message["content"]
assert "<!--start_document" in system_message["content"]
+@pytest.mark.filterwarnings("ignore:.*Unsupported file types detected!.*")
@patch('urllib.request.urlopen')
def test_chat_with_file_input_no_files(self, mock_urlopen):
"""Test chat functionality with input directory containing no supported files."""
@@ -385,6 +386,7 @@ class TestImageUploadChatIntegration:
for item in image_msg["content"]
)
+@pytest.mark.filterwarnings("ignore:.*Unsupported file types detected!.*")
@patch('urllib.request.urlopen')
def test_chat_with_image_input_unsupported_image_types(self, mock_urlopen):
"""Test chat functionality with unsupported image file types."""

View File

@@ -182,6 +182,7 @@ class TestFileUploadWithDataFiles:
assert "sample.md" in content
assert "sample.json" in content
+@pytest.mark.filterwarnings("ignore:.*Unsupported file types detected!.*")
def test_unsupported_file_handling(self, data_dir):
"""Test that unsupported files are handled correctly."""
@@ -310,6 +311,7 @@ class TestImageUploadWithDataFiles:
assert all("data:image/" in item["image_url"]["url"] for item in messages[0]["content"])
assert all("base64," in item["image_url"]["url"] for item in messages[0]["content"])
+@pytest.mark.filterwarnings("ignore:.*Unsupported file types detected!.*")
def test_image_unsupported_file_handling(self, data_dir):
"""Test that unsupported image files are handled correctly."""
with tempfile.TemporaryDirectory() as tmp_dir: