Mirror of https://github.com/containers/ramalama.git (synced 2026-02-05 06:46:39 +01:00)
various typing and bug fixes
Signed-off-by: Ian Eaves <ian.k.eaves@gmail.com>
Makefile
@@ -111,9 +111,8 @@ docs:
 .PHONY: lint
 lint:
 ifneq (,$(wildcard /usr/bin/python3))
-	/usr/bin/python3 -m compileall -q .
+	/usr/bin/python3 -m compileall -q -x '\.venv' .
 endif
-
 	! grep -ri --exclude-dir ".venv" --exclude-dir "*/.venv" "#\!/usr/bin/python3" .
 	flake8 $(PROJECT_DIR) $(PYTHON_SCRIPTS)
 	shellcheck *.sh */*.sh */*/*.sh

@@ -96,7 +96,7 @@ class RamaLamaShell(cmd.Cmd):
         operational_args = ChatOperationalArgs()

         super().__init__()
-        self.conversation_history = []
+        self.conversation_history: list[dict] = []
         self.args = args
         self.request_in_process = False
         self.prompt = args.prefix

@@ -16,7 +16,7 @@ from ramalama.config import COLOR_OPTIONS, SUPPORTED_RUNTIMES
 try:
     import argcomplete

-    suppressCompleter = argcomplete.completers.SuppressCompleter
+    suppressCompleter: type[argcomplete.completers.SuppressCompleter] | None = argcomplete.completers.SuppressCompleter
 except Exception:
     suppressCompleter = None

@@ -44,7 +44,6 @@ GENERATE_OPTIONS = ["quadlet", "kube", "quadlet/kube"]


 class ParsedGenerateInput:
-
     def __init__(self, gen_type: str, output_dir: str):
         self.gen_type = gen_type
         self.output_dir = output_dir

@@ -1226,7 +1225,6 @@ def inspect_cli(args):


 def main():
-
     def eprint(e, exit_code):
         perror("Error: " + str(e).strip("'\""))
         sys.exit(exit_code)

@@ -13,8 +13,9 @@ import shutil
 import string
 import subprocess
 import sys
+from collections.abc import Callable, Iterable
 from functools import lru_cache
-from typing import TYPE_CHECKING, Callable, List, Literal, Protocol, cast, get_args
+from typing import TYPE_CHECKING, Literal, Protocol, TypeAlias, TypedDict, cast, get_args

 import ramalama.amdkfd as amdkfd
 from ramalama.logger import logger

@@ -230,7 +231,15 @@ def engine_version(engine: SUPPORTED_ENGINES) -> str:
     return run_cmd(cmd_args).stdout.decode("utf-8").strip()


-def load_cdi_yaml(stream) -> dict:
+class CDI_DEVICE(TypedDict):
+    name: str
+
+
+class CDI_RETURN_TYPE(TypedDict):
+    devices: list[CDI_DEVICE]
+
+
+def load_cdi_yaml(stream: Iterable[str]) -> CDI_RETURN_TYPE:
     # Returns a dict containing just the "devices" key, whose value is
     # a list of dicts, each mapping the key "name" to a device name.
     # For example: {'devices': [{'name': 'all'}]}

@@ -238,7 +247,7 @@ def load_cdi_yaml(stream) -> dict:
     # under "devices" and the value of the "name" key being on the
     # same line following a colon.

-    data = {"devices": []}
+    data: CDI_RETURN_TYPE = {"devices": []}
     for line in stream:
         if ':' in line:
             key, value = line.split(':', 1)

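Reviewer's note on the two CDI hunks above: a minimal sketch (illustrative, not part of the commit) of what the new TypedDicts buy — the checker can now validate the shape of the parsed CDI data, including key names:

    from typing import TypedDict

    class CDI_DEVICE(TypedDict):
        name: str

    class CDI_RETURN_TYPE(TypedDict):
        devices: list[CDI_DEVICE]

    data: CDI_RETURN_TYPE = {"devices": []}
    data["devices"].append({"name": "all"})  # accepted
    data["devices"].append({"nam": "all"})   # runs, but mypy flags the unknown key "nam"
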
@@ -247,7 +256,7 @@ def load_cdi_yaml(stream) -> dict:
     return data


-def load_cdi_config(spec_dirs: List[str]) -> dict | None:
+def load_cdi_config(spec_dirs: list[str]) -> CDI_RETURN_TYPE | None:
     # Loads the first YAML or JSON CDI configuration file found in the
     # given directories."""

@@ -275,7 +284,7 @@ def load_cdi_config(spec_dirs: List[str]) -> dict | None:
     return None


-def find_in_cdi(devices: List[str]) -> tuple[List[str], List[str]]:
+def find_in_cdi(devices: list[str]) -> tuple[list[str], list[str]]:
     # Attempts to find a CDI configuration for each device in devices
     # and returns a list of configured devices and a list of
     # unconfigured devices.

@@ -327,11 +336,12 @@ def check_nvidia() -> Literal["cuda"] | None:
         return None

     smi_lines = result.stdout.splitlines()
-    parsed_lines = [[item.strip() for item in line.split(',')] for line in smi_lines if line]
+    parsed_lines: list[list[str]] = [[item.strip() for item in line.split(',')] for line in smi_lines if line]
+
+    if not parsed_lines:
+        return None

-    indices, uuids = zip(*parsed_lines) if parsed_lines else (tuple(), tuple())
+    indices, uuids = map(list, zip(*parsed_lines))
     # Get the list of devices specified by CUDA_VISIBLE_DEVICES, if any
     cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "")
     visible_devices = cuda_visible_devices.split(',') if cuda_visible_devices else []

@@ -342,14 +352,14 @@ def check_nvidia() -> Literal["cuda"] | None:

     configured, unconfigured = find_in_cdi(visible_devices + ["all"])

-    if unconfigured and "all" not in configured:
+    if unconfigured and not (configured_has_all := "all" in configured):
         perror(f"No CDI configuration found for {','.join(unconfigured)}")
         perror("You can use the \"nvidia-ctk cdi generate\" command from the ")
         perror("nvidia-container-toolkit to generate a CDI configuration.")
         perror("See ramalama-cuda(7).")
         return None
     elif configured:
-        if "all" in configured:
+        if configured_has_all:
             configured.remove("all")
             if not configured:
                 configured = indices

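One subtlety worth flagging in the hunk above: the walrus assignment sits on the right side of an `and`, so `configured_has_all` is only bound when `unconfigured` is truthy. A small sketch (hypothetical values) of the short-circuit case:

    configured = ["all"]
    unconfigured: list[str] = []

    if unconfigured and not (configured_has_all := "all" in configured):
        print("missing CDI configuration")
    elif configured:
        try:
            print(configured_has_all)
        except NameError:
            # the empty `unconfigured` short-circuited the `and`, so the
            # walrus assignment never ran and the name was never bound
            print("configured_has_all is unbound on this path")
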
@@ -442,7 +452,7 @@ def check_mthreads() -> Literal["musa"] | None:
     return None


-AccelType = Literal["asahi", "cuda", "cann", "hip", "intel", "musa"]
+AccelType: TypeAlias = Literal["asahi", "cuda", "cann", "hip", "intel", "musa"]


 def get_accel() -> AccelType | Literal["none"]:

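For readers unfamiliar with the `TypeAlias` annotations added throughout this commit: PEP 613 makes the alias explicit, so a checker treats the assignment as a type alias rather than an ordinary module-level variable. A minimal sketch (Python 3.10+):

    from typing import Literal, TypeAlias

    AccelType: TypeAlias = Literal["asahi", "cuda", "cann", "hip", "intel", "musa"]

    def get_accel() -> AccelType | Literal["none"]:
        return "none"  # the checker verifies the value is one of the allowed literals
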
@@ -474,7 +484,7 @@ def set_gpu_type_env_vars():
         get_accel()


-GPUEnvVar = Literal[
+GPUEnvVar: TypeAlias = Literal[
     "ASAHI_VISIBLE_DEVICES",
     "ASCEND_VISIBLE_DEVICES",
     "CUDA_VISIBLE_DEVICES",

@@ -486,10 +496,10 @@ GPUEnvVar = Literal[


 def get_gpu_type_env_vars() -> dict[GPUEnvVar, str]:
-    return {k: os.environ[k] for k in get_args(GPUEnvVar) if k in os.environ}
+    return {k: v for k in get_args(GPUEnvVar) if (v := os.environ.get(k))}


-AccelEnvVar = Literal[
+AccelEnvVar: TypeAlias = Literal[
     "CUDA_LAUNCH_BLOCKING",
     "HSA_VISIBLE_DEVICES",
     "HSA_OVERRIDE_GFX_VERSION",

@@ -498,7 +508,7 @@ AccelEnvVar = Literal[

 def get_accel_env_vars() -> dict[GPUEnvVar | AccelEnvVar, str]:
     gpu_env_vars: dict[GPUEnvVar, str] = get_gpu_type_env_vars()
-    accel_env_vars: dict[AccelEnvVar, str] = {k: os.environ[k] for k in get_args(AccelEnvVar) if k in os.environ}
+    accel_env_vars: dict[AccelEnvVar, str] = {k: v for k in get_args(AccelEnvVar) if (v := os.environ.get(k))}
     return gpu_env_vars | accel_env_vars

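Note that the two comprehension rewrites above are not purely cosmetic: `k in os.environ` keeps variables that are set to the empty string, while `(v := os.environ.get(k))` tests the value's truthiness and drops them. A small demonstration:

    import os

    os.environ["CUDA_VISIBLE_DEVICES"] = ""

    keys = ["CUDA_VISIBLE_DEVICES"]
    old = {k: os.environ[k] for k in keys if k in os.environ}
    new = {k: v for k in keys if (v := os.environ.get(k))}

    print(old)  # {'CUDA_VISIBLE_DEVICES': ''} - empty value kept
    print(new)  # {} - empty value dropped
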
@@ -599,7 +609,9 @@ class AccelImageArgsOtherRuntimeRAG(Protocol):
     quiet: bool


-AccelImageArgs = None | AccelImageArgsVLLMRuntime | AccelImageArgsOtherRuntime | AccelImageArgsOtherRuntimeRAG
+AccelImageArgs: TypeAlias = (
+    None | AccelImageArgsVLLMRuntime | AccelImageArgsOtherRuntime | AccelImageArgsOtherRuntimeRAG
+)


 def accel_image(config: Config) -> str:

@@ -627,7 +639,7 @@ def accel_image(config: Config) -> str:
     vers = minor_release()

     should_pull = config.pull in ["always", "missing"] and not config.dryrun
-    if attempt_to_use_versioned(config.engine, image, vers, True, should_pull):
+    if config.engine and attempt_to_use_versioned(config.engine, image, vers, True, should_pull):
         return f"{image}:{vers}"

     return f"{image}:latest"

@@ -3,19 +3,19 @@ import os
 import sys
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import Any, Literal, Mapping
+from typing import Any, Literal, Mapping, TypeAlias

 from ramalama.common import available
 from ramalama.layered_config import LayeredMixin, deep_merge
 from ramalama.toml_parser import TOMLParser

-PathStr = str
+PathStr: TypeAlias = str
 DEFAULT_PORT_RANGE: tuple[int, int] = (8080, 8090)
 DEFAULT_PORT: int = DEFAULT_PORT_RANGE[0]
-DEFAULT_IMAGE = "quay.io/ramalama/ramalama"
-SUPPORTED_ENGINES = Literal["podman", "docker"] | PathStr
-SUPPORTED_RUNTIMES = Literal["llama.cpp", "vllm", "mlx"]
-COLOR_OPTIONS = Literal["auto", "always", "never"]
+DEFAULT_IMAGE: str = "quay.io/ramalama/ramalama"
+SUPPORTED_ENGINES: TypeAlias = Literal["podman", "docker"] | PathStr
+SUPPORTED_RUNTIMES: TypeAlias = Literal["llama.cpp", "vllm", "mlx"]
+COLOR_OPTIONS: TypeAlias = Literal["auto", "always", "never"]


 def get_default_engine() -> SUPPORTED_ENGINES | None:

@@ -157,7 +157,7 @@ def load_env_config(env: Mapping[str, str] | None = None) -> dict[str, Any]:
     if env is None:
         env = os.environ

-    config = {}
+    config: dict[str, Any] = {}
     for k, v in env.items():
         if not k.startswith("RAMALAMA"):
             continue

@@ -23,7 +23,7 @@ class BaseFileManager(ABC):
         return loader

     @abstractmethod
-    def load(self):
+    def load(self, *args, **kwargs):
         pass

     @classmethod

@@ -121,12 +121,11 @@ class OpanAIChatAPIMessageBuilder:
         if unsupported_files:
             unsupported_files_warning(unsupported_files, list(self.supported_extensions()))

-        messages = []
+        messages: list[dict] = []
         if text_files:
             messages.append({"role": "system", "content": self.text_manager.load(text_files)})
         if image_files:
-            message = {"role": "system", "content": []}
-            for content in self.image_manager.load(image_files):
-                message['content'].append({"type": "image_url", "image_url": {"url": content}})
+            content = [{"type": "image_url", "image_url": {"url": c}} for c in self.image_manager.load(image_files)]
+            message = {"role": "system", "content": content}
             messages.append(message)
         return messages

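The loop-to-comprehension rewrite above is behavior-preserving and also stops shadowing the loop name `content`. An equivalence sketch with placeholder data:

    urls = ["data:image/png;base64,AAA", "data:image/png;base64,BBB"]

    # before: mutate a list inside the message dict
    message = {"role": "system", "content": []}
    for content in urls:
        message["content"].append({"type": "image_url", "image_url": {"url": content}})

    # after: build the list first, then the message
    content = [{"type": "image_url", "image_url": {"url": c}} for c in urls]
    assert message["content"] == content
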
@@ -64,7 +64,7 @@ class HFStyleRepository(ABC):
         self.name = name
         self.organization = organization
         self.tag = tag
-        self.headers = {}
+        self.headers: dict = {}
         self.blob_url = None
         self.model_filename = None
         self.model_hash = None

@@ -5,6 +5,7 @@ import shlex
 import socket
 import sys
 import time
 from abc import ABC, abstractmethod
+from typing import Optional

 import ramalama.chat as chat

@@ -54,7 +55,6 @@ $(error)s"""


 class NoRefFileFound(Exception):
-
     def __init__(self, model: str, *args):
         super().__init__(*args)

@@ -74,7 +74,10 @@ def trim_model_name(model):
     return model


-class ModelBase:
+class ModelBase(ABC):
+    model: str
+    type: str
+
     def __not_implemented_error(self, param):
         return NotImplementedError(f"ramalama {param} for '{type(self).__name__}' not implemented")

@@ -90,24 +93,31 @@ class ModelBase:
     def push(self, source_model, args):
         raise self.__not_implemented_error("push")

+    @abstractmethod
     def remove(self, args):
         raise self.__not_implemented_error("rm")

+    @abstractmethod
     def bench(self, args):
         raise self.__not_implemented_error("bench")

+    @abstractmethod
     def run(self, args):
         raise self.__not_implemented_error("run")

+    @abstractmethod
     def perplexity(self, args):
         raise self.__not_implemented_error("perplexity")

+    @abstractmethod
     def serve(self, args):
         raise self.__not_implemented_error("serve")

+    @abstractmethod
     def exists(self) -> bool:
         raise self.__not_implemented_error("exists")

+    @abstractmethod
     def inspect(self, args):
         raise self.__not_implemented_error("inspect")

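With `ModelBase(ABC)` plus the `@abstractmethod` markers above, an incomplete subclass now fails at construction time instead of when the missing method is first called. A condensed sketch:

    from abc import ABC, abstractmethod

    class ModelBase(ABC):
        @abstractmethod
        def serve(self, args):
            raise NotImplementedError("serve")

    class Incomplete(ModelBase):
        pass

    try:
        Incomplete()
    except TypeError as err:
        print(err)  # Can't instantiate abstract class Incomplete with abstract method serve
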
@@ -115,15 +125,14 @@ class ModelBase:
 class Model(ModelBase):
     """Model super class"""

-    model = ""
-    type = "Model"
+    type: str = "Model"

-    def __init__(self, model, model_store_path):
+    def __init__(self, model: str, model_store_path: str):
         self.model = model

-        split = self.model.rsplit("/", 1)
-        self.directory = split[0] if len(split) > 1 else ""
-        self.filename = split[1] if len(split) > 1 else split[0]
+        split: list[str] = self.model.rsplit("/", 1)
+        self.directory: str = split[0] if len(split) > 1 else ""
+        self.filename: str = split[1] if len(split) > 1 else split[0]

         self._model_name: str
         self._model_tag: str

@@ -419,7 +428,7 @@ class Model(ModelBase):
                 chat.chat(args)
                 break
             else:
-                logger.debug(f"MLX server not ready, waiting... (attempt {i+1}/{max_retries})")
+                logger.debug(f"MLX server not ready, waiting... (attempt {i + 1}/{max_retries})")
                 time.sleep(3)
                 continue

@@ -428,7 +437,7 @@ class Model(ModelBase):
             perror(f"Error: Failed to connect to MLX server after {max_retries} attempts: {e}")
             self._cleanup_server_process(args.pid2kill)
             raise e
-        logger.debug(f"Connection attempt failed, retrying... (attempt {i+1}/{max_retries}): {e}")
+        logger.debug(f"Connection attempt failed, retrying... (attempt {i + 1}/{max_retries}): {e}")
         time.sleep(3)

     args.initial_connection = False

@@ -685,7 +694,6 @@ class Model(ModelBase):
         return exec_args

     def generate_container_config(self, args, exec_args):
-
         # Get the blob paths (src) and mounted paths (dest)
         model_src_path = self._get_entry_model_path(False, False, args.dryrun)
         chat_template_src_path = self._get_chat_template_path(False, False, args.dryrun)

@@ -776,7 +784,7 @@ class Model(ModelBase):
         kube = Kube(self.model_name, model_paths, chat_template_paths, mmproj_paths, args, exec_args)
         kube.generate().write(output_dir)

-    def inspect(self, args):
+    def inspect(self, args) -> None:
         self.ensure_model_exists(args)

         model_name = self.filename

@@ -1,8 +1,9 @@
 import copy
-from typing import Callable, Tuple, Union
+from collections.abc import Callable
+from typing import TypeAlias
 from urllib.parse import urlparse

-from ramalama.arg_types import StoreArgs
+from ramalama.arg_types import StoreArgType
 from ramalama.common import rm_until_substring
 from ramalama.config import CONFIG
 from ramalama.huggingface import Huggingface

@@ -12,12 +13,14 @@ from ramalama.oci import OCI
 from ramalama.ollama import Ollama
 from ramalama.url import URL

+CLASS_MODEL_TYPES: TypeAlias = Huggingface | Ollama | OCI | URL | ModelScope
+

 class ModelFactory:
     def __init__(
         self,
         model: str,
-        args: StoreArgs,
+        args: StoreArgType,
         transport: str = "ollama",
         ignore_stderr: bool = False,
     ):

@@ -28,20 +31,20 @@ class ModelFactory:
         self.ignore_stderr = ignore_stderr
         self.container = args.container

-        self.model_cls: type[Union[Huggingface, ModelScope, Ollama, OCI, URL]]
-        self.create: Callable[[], Union[Huggingface, ModelScope, Ollama, OCI, URL]]
-        self.model_cls, self.create = self.detect_model_model_type()
+        model_cls, _create = self.detect_model_model_type()
+
+        self.model_cls = model_cls
+        self._create = _create

         self.pruned_model = self.prune_model_input()
         self.draft_model = None

         if getattr(args, 'model_draft', None):
             dm_args = copy.deepcopy(args)
             dm_args.model_draft = None
             self.draft_model = ModelFactory(args.model_draft, dm_args, ignore_stderr=True).create()

-    def detect_model_model_type(
-        self,
-    ) -> Tuple[type[Union[Huggingface, Ollama, OCI, URL]], Callable[[], Union[Huggingface, Ollama, OCI, URL]]]:
+    def detect_model_model_type(self) -> tuple[type[CLASS_MODEL_TYPES], Callable[[], CLASS_MODEL_TYPES]]:
         for prefix in ["huggingface://", "hf://", "hf.co/"]:
             if self.model.startswith(prefix):
                 return Huggingface, self.create_huggingface

@@ -89,6 +92,9 @@ class ModelFactory:
             if self.model.startswith(t + "://"):
                 raise ValueError(f"{self.model} invalid: Only OCI Model types supported")

+    def create(self) -> CLASS_MODEL_TYPES:
+        return self._create()
+
     def create_huggingface(self) -> Huggingface:
         model = Huggingface(self.pruned_model, self.store_path)
         model.draft_model = self.draft_model

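The factory now stores the detected zero-argument constructor privately and exposes it through a `create()` method, so call sites keep the `factory.create()` spelling while the attribute no longer collides with the method name. A sketch of the shape only (the lambda stands in for the real prefix-based dispatch):

    from typing import Callable

    class ModelFactory:
        def __init__(self) -> None:
            self.model_cls, self._create = self.detect_model_model_type()

        def detect_model_model_type(self) -> tuple[type, Callable[[], object]]:
            return str, lambda: "a model"  # stand-in for the real detection logic

        def create(self) -> object:
            return self._create()

    print(ModelFactory().create())  # a model
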
@@ -108,7 +114,11 @@ class ModelFactory:
         if not self.container:
             raise ValueError("OCI containers cannot be used with the --nocontainer option.")

+        if self.engine is None:
+            raise ValueError("Constructing an OCI model factory requires an engine value")
+
         self.validate_oci_model_input()
+
         model = OCI(self.pruned_model, self.store_path, self.engine, self.ignore_stderr)
         model.draft_model = self.draft_model
         return model

@@ -119,13 +129,13 @@ class ModelFactory:
         return model


-def New(name, args, transport: str = None) -> Union[Huggingface | ModelScope | Ollama | OCI | URL]:
+def New(name, args, transport: str | None = None) -> CLASS_MODEL_TYPES:
     if transport is None:
         transport = CONFIG.transport
     return ModelFactory(name, args, transport=transport).create()


-def Serve(name, args):
+def Serve(name, args) -> None:
     model = New(name, args)
     try:
         model.serve(args)

@@ -53,7 +53,7 @@ class GGUFModelInfo(ModelInfoBase):
         i = 0
         for tensor in self.Tensors:
             ret = ret + adjust_new_line(
-                f"    {i}: {tensor.name, tensor.type.name, tensor.n_dimensions, tensor.offset}"
+                f"    {i}: {tensor.name, tensor.type, tensor.n_dimensions, tensor.offset}"
             )
             i = i + 1

@@ -1,7 +1,7 @@
 import io
 import struct
 from enum import IntEnum
-from typing import Any, Dict
+from typing import Any, Dict, cast

 from ramalama.endian import GGUFEndian
 from ramalama.logger import logger

@@ -116,7 +116,7 @@ class GGUFInfoParser:
         model: io.BufferedReader, model_endianness: GGUFEndian = GGUFEndian.LITTLE, length: int = -1
     ) -> str:
         if length == -1:
-            length = GGUFInfoParser.read_number(model, GGUFValueType.UINT64, model_endianness)
+            length = cast(int, GGUFInfoParser.read_number(model, GGUFValueType.UINT64, model_endianness))

         raw = model.read(length)
         if len(raw) < length:

@@ -144,12 +144,14 @@ class GGUFInfoParser:

     @staticmethod
     def read_value_type(model: io.BufferedReader, model_endianness: GGUFEndian) -> GGUFValueType:
-        value_type = GGUFInfoParser.read_number(model, GGUFValueType.UINT32, model_endianness)
+        value_type = cast(int, GGUFInfoParser.read_number(model, GGUFValueType.UINT32, model_endianness))
         return GGUFValueType(value_type)

     @staticmethod
-    def read_value(model: io.BufferedReader, value_type: GGUFValueType, model_endianness: GGUFEndian) -> Any:
-        value = None
+    def read_value(
+        model: io.BufferedReader, value_type: GGUFValueType, model_endianness: GGUFEndian
+    ) -> str | int | float | bool | list:
+        value: Any
         if value_type in GGUF_NUMBER_FORMATS:
             value = GGUFInfoParser.read_number(model, value_type, model_endianness)
         elif value_type == GGUFValueType.BOOL:

@@ -158,12 +160,12 @@ class GGUFInfoParser:
             value = GGUFInfoParser.read_string(model, model_endianness)
         elif value_type == GGUFValueType.ARRAY:
             array_type = GGUFInfoParser.read_value_type(model, model_endianness)
-            array_length = GGUFInfoParser.read_number(model, GGUFValueType.UINT64, model_endianness)
+            array_length = cast(int, GGUFInfoParser.read_number(model, GGUFValueType.UINT64, model_endianness))
             value = [GGUFInfoParser.read_value(model, array_type, model_endianness) for _ in range(array_length)]
         else:
             raise ParseError(f"Unknown type '{value_type}'")

-        if value is not None:
-            return value
-        raise ParseError(f"Unknown type '{value_type}'")
+        return value

     @staticmethod
     def get_model_endianness(model_path: str) -> GGUFEndian:

@@ -176,7 +178,7 @@ class GGUFInfoParser:
         if magic_number != GGUFModelInfo.MAGIC_NUMBER:
             raise ParseError(f"Invalid GGUF magic number '{magic_number}'")

-        gguf_version = GGUFInfoParser.read_number(model, GGUFValueType.UINT32, model_endianness)
+        gguf_version = cast(int, GGUFInfoParser.read_number(model, GGUFValueType.UINT32, model_endianness))
         if gguf_version & 0xFFFF == 0x0000:
             model_endianness = GGUFEndian.BIG

@@ -191,10 +193,10 @@ class GGUFInfoParser:
         if magic_number != GGUFModelInfo.MAGIC_NUMBER:
             raise ParseError(f"Invalid GGUF magic number '{magic_number}'")

-        gguf_version = GGUFInfoParser.read_number(model, GGUFValueType.UINT32, model_endianness)
+        gguf_version = cast(int, GGUFInfoParser.read_number(model, GGUFValueType.UINT32, model_endianness))

-        tensor_count = GGUFInfoParser.read_number(model, GGUFValueType.UINT64, model_endianness)
-        metadata_kv_count = GGUFInfoParser.read_number(model, GGUFValueType.UINT64, model_endianness)
+        tensor_count = cast(int, GGUFInfoParser.read_number(model, GGUFValueType.UINT64, model_endianness))
+        metadata_kv_count = cast(int, GGUFInfoParser.read_number(model, GGUFValueType.UINT64, model_endianness))

         metadata = {}
         for _ in range(metadata_kv_count):

@@ -205,13 +207,17 @@ class GGUFInfoParser:
         tensors: list[Tensor] = []
         for _ in range(tensor_count):
             name = GGUFInfoParser.read_string(model, model_endianness)
-            n_dimensions = GGUFInfoParser.read_number(model, GGUFValueType.UINT32, model_endianness)
+            n_dimensions = cast(int, GGUFInfoParser.read_number(model, GGUFValueType.UINT32, model_endianness))
             dimensions: list[int] = []
             for _ in range(n_dimensions):
-                dimensions.append(GGUFInfoParser.read_number(model, GGUFValueType.UINT64, model_endianness))
-            tensor_type = GGML_TYPE(GGUFInfoParser.read_number(model, GGUFValueType.UINT32, model_endianness))
-            offset = GGUFInfoParser.read_number(model, GGUFValueType.UINT64, model_endianness)
-            tensors.append(Tensor(name, n_dimensions, dimensions, tensor_type, offset))
+                dim = cast(int, GGUFInfoParser.read_number(model, GGUFValueType.UINT64, model_endianness))
+                dimensions.append(dim)
+            tensor_type = GGML_TYPE(
+                cast(int, GGUFInfoParser.read_number(model, GGUFValueType.UINT32, model_endianness))
+            )
+
+            offset = cast(int, GGUFInfoParser.read_number(model, GGUFValueType.UINT64, model_endianness))
+            tensors.append(Tensor(name, n_dimensions, dimensions, tensor_type.name, offset))

         return GGUFModelInfo(
             model_name, model_registry, model_path, gguf_version, metadata, tensors, model_endianness

@@ -71,7 +71,7 @@ class GlobalModelStore:
                 last_modified = os.path.getmtime(snapshot_file_path)
                 file_size = os.path.getsize(snapshot_file_path)
                 collected_files.append(
-                    ModelFile(snapshot_file, last_modified, file_size, is_partially_downloaded)
+                    ModelFile(snapshot_file.name, last_modified, file_size, is_partially_downloaded)
                 )
             models[model_name] = collected_files

@@ -7,7 +7,7 @@ import re
 import sys
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import Dict, Optional
+from typing import Dict, Optional, cast


 class NodeType(Enum):

@@ -298,7 +298,6 @@ def parse_go_template(content: str) -> list[Node]:
             artificial=False,
         )
         root_nodes.append(content_node)
-
     return root_nodes

@@ -353,7 +352,6 @@ def translate_continue_nodes(root_nodes: list[Node]) -> list[Node]:
         parent_node.children = parent_node.children[:start_index] + [if_node, end_node]

     find_continue_nodes(root_nodes)
-
     skip_variable = "$should_continue"
     for continue_node in continue_nodes:
         # find start of loop to initialize continue skip variable

@@ -376,6 +374,7 @@ def translate_continue_nodes(root_nodes: list[Node]) -> list[Node]:
             children=[],
             artificial=True,
         )
+        for_node.next = cast(Node, for_node.next)
         for_node.next.prev = initial_set_node
         for_node.next = initial_set_node
         for_node.children = [initial_set_node] + for_node.children

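The added `cast(Node, for_node.next)` line above is a typing device only: `typing.cast` returns its argument unchanged and performs no runtime check, so the assignment simply convinces the checker that `for_node.next` is no longer `Optional`. Minimal illustration:

    from typing import cast

    def narrowed(node: object | None) -> object:
        # no validation happens here; cast() is the identity function at runtime
        return cast(object, node)

    assert narrowed(None) is None
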
@@ -553,13 +552,13 @@ def go_to_jinja(content: str) -> str:
         if m is None:
             return ""

-        if node.open_scope_node.type in [NodeType.IF, NodeType.ELIF, NodeType.ELSE]:
+        if node.open_scope_node.type in [NodeType.IF, NodeType.ELIF, NodeType.ELSE]:  # type: ignore
             return (
                 node.content[: m.start(1)].replace(GO_SYMBOL_OPEN_BRACKETS, JINJA_SYMBOL_OPEN_BRACKETS)
                 + "endif"
                 + node.content[m.end(1) :].replace(GO_SYMBOL_CLOSE_BRACKETS, JINJA_SYMBOL_CLOSE_BRACKETS)
             )
-        elif node.open_scope_node.type == NodeType.RANGE:
+        elif node.open_scope_node.type == NodeType.RANGE:  # type: ignore
             loop_vars.pop()
             if loop_index_vars:
                 loop_index_vars.pop()

@@ -15,7 +15,7 @@ class RefFile:
     CHAT_TEMPLATE_SUFFIX = "chat"
     MMPROJ_SUFFIX = "mmproj"

-    def __init__(self):
+    def __init__(self) -> None:
         self.hash: str = ""
         self.filenames: list[str] = []
         self.model_name: str = ""

@@ -133,7 +133,6 @@ def migrate_reffile_to_refjsonfile(ref_file_path: str, snapshot_directory: str)


 class StoreFileType(StrEnum):
-
     GGUF_MODEL = "gguf"
     MMPROJ = "mmproj"
     CHAT_TEMPLATE = "chat_template"

@@ -152,7 +151,6 @@ class StoreFileType(StrEnum):

 @dataclass
 class StoreFile:
-
     hash: str
     name: str
     type: StoreFileType

@@ -160,7 +158,6 @@ class StoreFile:

 @dataclass
 class RefJSONFile:
-
     hash: str
     path: str
     files: list[StoreFile]

@@ -1,6 +1,6 @@
 import os
 from enum import IntEnum
-from typing import Dict
+from typing import Dict, Sequence

 from ramalama.common import generate_sha256
 from ramalama.http_client import download_file

@@ -60,7 +60,7 @@ class LocalSnapshotFile(SnapshotFile):
     ):
         super().__init__(
             "",
-            "",
+            {},
             generate_sha256(content),
             name,
             type,

@@ -77,7 +77,7 @@ class LocalSnapshotFile(SnapshotFile):
         return os.path.relpath(blob_file_path, start=snapshot_dir)


-def validate_snapshot_files(snapshot_files: list[SnapshotFile]):
+def validate_snapshot_files(snapshot_files: Sequence[SnapshotFile]):
     chat_template_files = []
     mmproj_files = []
     for file in snapshot_files:

@@ -4,7 +4,7 @@ import urllib.error
 from collections import Counter
 from http import HTTPStatus
 from pathlib import Path
-from typing import Optional, Tuple
+from typing import Optional, Sequence, Tuple

 import ramalama.model_store.go2jinja as go2jinja
 from ramalama.common import perror, sanitize_filename, verify_checksum

@@ -97,9 +97,12 @@ class ModelStore:
         return RefJSONFile.from_path(ref_file_path)

     def update_ref_file(
-        self, model_tag: str, snapshot_hash: str = "", snapshot_files: list[SnapshotFile] = []
+        self, model_tag: str, snapshot_hash: str = "", snapshot_files: Optional[list[SnapshotFile]] = None
     ) -> Optional[RefJSONFile]:
-        ref_file: RefJSONFile = self.get_ref_file(model_tag)
+        if snapshot_files is None:
+            snapshot_files = []
+
+        ref_file: RefJSONFile | None = self.get_ref_file(model_tag)
         if ref_file is None:
             return None

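The signature change above fixes the classic mutable-default pitfall: a `list` default is created once, at function definition, and shared by every call. A demonstration with a hypothetical helper:

    def append_shared(item, files=[]):   # one list object reused across calls
        files.append(item)
        return files

    def append_fresh(item, files=None):  # None sentinel, new list per call
        if files is None:
            files = []
        files.append(item)
        return files

    print(append_shared("a"), append_shared("b"))  # ['a', 'b'] ['a', 'b'] - state leaks
    print(append_fresh("a"), append_fresh("b"))    # ['a'] ['b']
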
@@ -154,9 +157,9 @@ class ModelStore:
         )

     def get_cached_files(self, model_tag: str) -> Tuple[str, list[str], bool]:
-        cached_files = []
+        cached_files: list[str] = []

-        ref_file: RefJSONFile = self.get_ref_file(model_tag)
+        ref_file: RefJSONFile | None = self.get_ref_file(model_tag)
         if ref_file is None:
             return ("", cached_files, False)

@@ -182,8 +185,10 @@ class ModelStore:
         snapshot_directory = self.get_snapshot_directory(snapshot_hash)
         os.makedirs(snapshot_directory, exist_ok=True)

-    def _download_snapshot_files(self, model_tag: str, snapshot_hash: str, snapshot_files: list[SnapshotFile]):
-        ref_file = self.get_ref_file(model_tag)
+    def _download_snapshot_files(self, model_tag: str, snapshot_hash: str, snapshot_files: Sequence[SnapshotFile]):
+        ref_file: None | RefJSONFile = self.get_ref_file(model_tag)
+        if ref_file is None:
+            raise ValueError("Cannot download snapshots without a valid ref file.")

         for file in snapshot_files:
             dest_path = self.get_blob_file_path(file.hash)

@@ -223,8 +228,8 @@ class ModelStore:
             if file.type == SnapshotFileType.ChatTemplate:
                 chat_template_file_path = self.get_blob_file_path(file.hash)
                 chat_template = ""
-                with open(chat_template_file_path, "r") as file:
-                    chat_template = file.read()
+                with open(chat_template_file_path, "r") as template_file:
+                    chat_template = template_file.read()

                 if not go2jinja.is_go_template(chat_template):
                     return

@@ -321,7 +326,7 @@ class ModelStore:
             self.remove_snapshot(model_tag)
             raise ex

-    def update_snapshot(self, model_tag: str, snapshot_hash: str, new_snapshot_files: list[SnapshotFile]) -> bool:
+    def update_snapshot(self, model_tag: str, snapshot_hash: str, new_snapshot_files: Sequence[SnapshotFile]) -> bool:
         validate_snapshot_files(new_snapshot_files)
         snapshot_hash = sanitize_filename(snapshot_hash)

@@ -363,7 +368,7 @@ class ModelStore:
             for entry in os.listdir(self.refs_directory)
             if os.path.isfile(os.path.join(self.refs_directory, entry))
         ]
-        refs = [self.get_ref_file(tag) for tag in model_tags]
+        refs = [ref for tag in model_tags if (ref := self.get_ref_file(tag))]

         blob_refcounts = Counter(file.name for ref in refs for file in ref.files)

@@ -130,10 +130,11 @@ def list_models(args: EngineArgType):


 class OCI(Model):
-    def __init__(self, model, model_store_path, conman, ignore_stderr=False):
+    type = "OCI"
+
+    def __init__(self, model: str, model_store_path: str, conman: str, ignore_stderr: bool = False):
         super().__init__(model, model_store_path)

-        self.type = "OCI"
     if not conman:
         raise ValueError("RamaLama OCI Images requires a container engine")

@@ -139,12 +139,12 @@ class OllamaRepository:


 class Ollama(Model):
-    def __init__(self, model, model_store_path):
+    def __init__(self, model, model_store_path) -> None:
         super().__init__(model, model_store_path)

         self.type = "Ollama"

-    def extract_model_identifiers(self):
+    def extract_model_identifiers(self) -> tuple[str, str, str]:
         model_name, model_tag, model_organization = super().extract_model_identifiers()

         # use the ollama default namespace if no model organization has been identified

@@ -152,11 +152,11 @@ class Ollama(Model):
             model_organization = "library"
         return model_name, model_tag, model_organization

-    def resolve_model(self):
+    def resolve_model(self) -> str:
         name, tag, organization = self.extract_model_identifiers()
         return f"ollama://{organization}/{name}:{tag}"

-    def pull(self, args):
+    def pull(self, args) -> None:
         name, tag, organization = self.extract_model_identifiers()
         _, cached_files, all = self.model_store.get_cached_files(tag)
         if all:

@@ -12,17 +12,17 @@ INPUT_DIR = "/docs"


 class Rag:
-    model = ""
-    target = ""
-    urls = []
+    model: str = ""
+    target: str = ""
+    urls: list[str] = []

-    def __init__(self, target):
+    def __init__(self, target: str):
         if not target.islower():
             raise ValueError(f"invalid reference format: repository name '{target}' must be lowercase")
         self.target = target
         set_accel_env_vars()

-    def build(self, source, target, args):
+    def build(self, source: str, target: str, args):
         perror(f"\nBuilding {target} ...")
         contextdir = os.path.dirname(source)
         src = os.path.basename(source)

@@ -60,7 +60,7 @@ COPY {src} /vector.db
         )
         return imageid

-    def _handle_paths(self, path):
+    def _handle_paths(self, path: str):
         """Adds a volume mount if path exists, otherwise add URL."""
         parsed = urlparse(path)
         if parsed.scheme in ["file", ""] and parsed.netloc == "":

@@ -135,7 +135,7 @@ COPY {src} /vector.db
         shutil.rmtree(ragdb.name, ignore_errors=True)


-def rag_image(image) -> str:
+def rag_image(image: str) -> str:
     imagespec = image.split(":")
     rag_image = f"{imagespec[0]}-rag"
     if len(imagespec) > 1:

@@ -104,7 +104,7 @@ class URL(Model):

         return files

-    def pull(self, _):
+    def pull(self, _) -> None:
         name, tag, _ = self.extract_model_identifiers()
         _, _, all_files = self.model_store.get_cached_files(tag)
         if all_files:

@@ -313,6 +313,7 @@ class TestOpanAIChatAPIMessageBuilder:
         assert isinstance(messages[1]["content"], list)
         assert len(messages[1]["content"]) == 1

+    @pytest.mark.filterwarnings("ignore:.*Unsupported file types detected!.*")
     def test_builder_load_no_supported_files(self):
         """Test loading directory with no supported files."""
         with tempfile.TemporaryDirectory() as tmp_dir:

@@ -70,6 +70,7 @@ class TestFileUploadChatIntegration:
         assert "readme.md" in system_message["content"]
         assert "<!--start_document" in system_message["content"]

+    @pytest.mark.filterwarnings("ignore:.*Unsupported file types detected!.*")
     @patch('urllib.request.urlopen')
     def test_chat_with_file_input_no_files(self, mock_urlopen):
         """Test chat functionality with input directory containing no supported files."""

@@ -385,6 +386,7 @@ class TestImageUploadChatIntegration:
             for item in image_msg["content"]
         )

+    @pytest.mark.filterwarnings("ignore:.*Unsupported file types detected!.*")
     @patch('urllib.request.urlopen')
     def test_chat_with_image_input_unsupported_image_types(self, mock_urlopen):
         """Test chat functionality with unsupported image file types."""

@@ -182,6 +182,7 @@ class TestFileUploadWithDataFiles:
         assert "sample.md" in content
         assert "sample.json" in content

+    @pytest.mark.filterwarnings("ignore:.*Unsupported file types detected!.*")
     def test_unsupported_file_handling(self, data_dir):
         """Test that unsupported files are handled correctly."""

@@ -310,6 +311,7 @@ class TestImageUploadWithDataFiles:
         assert all("data:image/" in item["image_url"]["url"] for item in messages[0]["content"])
         assert all("base64," in item["image_url"]["url"] for item in messages[0]["content"])

+    @pytest.mark.filterwarnings("ignore:.*Unsupported file types detected!.*")
     def test_image_unsupported_file_handling(self, data_dir):
         """Test that unsupported image files are handled correctly."""
         with tempfile.TemporaryDirectory() as tmp_dir: