mirror of
https://github.com/containers/ramalama.git
synced 2026-02-05 06:46:39 +01:00
add an option to disable verification of models after they're pulled
This allows models of a different endianness to be pulled and inspected. Endianness mismatches are now handled more gracefully in the CLI. Signed-off-by: Mike Bonnet <mikeb@redhat.com>
This commit is contained in:
@@ -20,6 +20,9 @@ Print usage message
|
||||
#### **--tls-verify**=*true*
|
||||
require HTTPS and verify certificates when contacting OCI registries
|
||||
|
||||
#### **--verify**=*true*
|
||||
verify the model after pull, disable to allow pulling of models with different endianness
|
||||
|
||||
## SEE ALSO
|
||||
**[ramalama(1)](ramalama.1.md)**
|
||||
|
||||
|
||||
@@ -27,6 +27,7 @@ from ramalama import engine
|
||||
from ramalama.chat import default_prefix
|
||||
from ramalama.common import accel_image, get_accel, perror
|
||||
from ramalama.config import CONFIG, coerce_to_bool, load_file_config
|
||||
from ramalama.endian import EndianMismatchError
|
||||
from ramalama.logger import configure_logger, logger
|
||||
from ramalama.model import (
|
||||
MODEL_TYPES,
|
||||
@@ -626,6 +627,12 @@ def pull_parser(subparsers):
|
||||
default=True,
|
||||
help="require HTTPS and verify certificates when contacting registries",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--verify",
|
||||
default=CONFIG.verify,
|
||||
action=CoerceToBool,
|
||||
help="verify the model after pull, disable to allow pulling of models with different endianness",
|
||||
)
|
||||
parser.add_argument("MODEL", completer=suppressCompleter) # positional argument
|
||||
parser.set_defaults(func=pull_cli)
|
||||
|
||||
@@ -1427,6 +1434,8 @@ def main():
|
||||
eprint(e, errno.ENOSYS)
|
||||
except subprocess.CalledProcessError as e:
|
||||
eprint(e, e.returncode)
|
||||
except EndianMismatchError:
|
||||
sys.exit(1)
|
||||
except KeyboardInterrupt:
|
||||
sys.exit(0)
|
||||
except IOError as e:
|
||||
|
||||
@@ -106,6 +106,7 @@ class BaseConfig:
|
||||
threads: int = -1
|
||||
transport: str = "ollama"
|
||||
user: UserConfig = field(default_factory=UserConfig)
|
||||
verify: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
self.container = coerce_to_bool(self.container) if self.container is not None else self.engine is not None
|
||||
@@ -189,7 +190,7 @@ def load_env_config(env: Mapping[str, str] | None = None) -> dict[str, Any]:
|
||||
if 'images' in config:
|
||||
config['images'] = json.loads(config['images'])
|
||||
|
||||
for key in ['ocr', 'keep_groups', 'container']:
|
||||
for key in ['ocr', 'keep_groups', 'container', 'verify']:
|
||||
if key in config:
|
||||
config[key] = coerce_to_bool(config[key])
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ from ramalama.common import (
|
||||
perror,
|
||||
run_cmd,
|
||||
)
|
||||
from ramalama.endian import EndianMismatchError
|
||||
from ramalama.logger import logger
|
||||
from ramalama.model import Model
|
||||
from ramalama.model_store.snapshot_file import SnapshotFile, SnapshotFileType
|
||||
@@ -264,8 +265,11 @@ class HFStyleRepoModel(Model, ABC):
|
||||
repo = self.create_repository(name, organization, tag)
|
||||
snapshot_hash = repo.model_hash
|
||||
files = repo.get_file_list(cached_files)
|
||||
self.model_store.new_snapshot(tag, snapshot_hash, files)
|
||||
self.model_store.new_snapshot(tag, snapshot_hash, files, verify=getattr(args, "verify", True))
|
||||
|
||||
except EndianMismatchError:
|
||||
# No use pulling again
|
||||
raise
|
||||
except Exception as e:
|
||||
if not available(self.get_cli_command()):
|
||||
perror(f"URL pull failed and {self.get_cli_command()} not available")
|
||||
@@ -278,7 +282,7 @@ class HFStyleRepoModel(Model, ABC):
|
||||
run_cmd(conman_args)
|
||||
|
||||
snapshot_hash, files = self._collect_cli_files(tempdir)
|
||||
self.model_store.new_snapshot(tag, snapshot_hash, files)
|
||||
self.model_store.new_snapshot(tag, snapshot_hash, files, verify=getattr(args, "verify", True))
|
||||
|
||||
def exec(self, cmd_args, args):
|
||||
try:
|
||||
|
||||
@@ -309,7 +309,7 @@ class ModelStore:
|
||||
self._verify_endianness(model_tag)
|
||||
self._store.verify_snapshot()
|
||||
|
||||
def new_snapshot(self, model_tag: str, snapshot_hash: str, snapshot_files: list[SnapshotFile]):
|
||||
def new_snapshot(self, model_tag: str, snapshot_hash: str, snapshot_files: list[SnapshotFile], verify: bool = True):
|
||||
snapshot_hash = sanitize_filename(snapshot_hash)
|
||||
|
||||
try:
|
||||
@@ -327,7 +327,8 @@ class ModelStore:
|
||||
raise ex
|
||||
|
||||
try:
|
||||
self.verify_snapshot(model_tag)
|
||||
if verify:
|
||||
self.verify_snapshot(model_tag)
|
||||
except EndianMismatchError as ex:
|
||||
perror(f"Verification of snapshot failed: {ex}")
|
||||
perror("Removing snapshot...")
|
||||
|
||||
@@ -188,7 +188,7 @@ class Ollama(Model):
|
||||
self.print_pull_message(f"ollama://{organization}/{name}:{tag}")
|
||||
|
||||
model_hash = ollama_repo.get_model_hash(manifest)
|
||||
self.model_store.new_snapshot(tag, model_hash, files)
|
||||
self.model_store.new_snapshot(tag, model_hash, files, verify=getattr(args, "verify", True))
|
||||
|
||||
# If a model has been downloaded via ollama cli, only create symlink in the snapshots directory
|
||||
if is_model_in_ollama_cache:
|
||||
|
||||
@@ -104,7 +104,7 @@ class URL(Model):
|
||||
|
||||
return files
|
||||
|
||||
def pull(self, _) -> None:
|
||||
def pull(self, args) -> None:
|
||||
name, tag, _ = self.extract_model_identifiers()
|
||||
_, _, all_files = self.model_store.get_cached_files(tag)
|
||||
if all_files:
|
||||
@@ -122,12 +122,12 @@ class URL(Model):
|
||||
required=True,
|
||||
)
|
||||
)
|
||||
self.model_store.new_snapshot(tag, snapshot_hash, files)
|
||||
self.model_store.new_snapshot(tag, snapshot_hash, files, verify=getattr(args, "verify", True))
|
||||
return
|
||||
|
||||
if is_split_file_model(self.model):
|
||||
files = self._assemble_split_file_list(snapshot_hash)
|
||||
self.model_store.new_snapshot(tag, snapshot_hash, files)
|
||||
self.model_store.new_snapshot(tag, snapshot_hash, files, verify=getattr(args, "verify", True))
|
||||
return
|
||||
|
||||
files.append(
|
||||
@@ -141,5 +141,5 @@ class URL(Model):
|
||||
required=True,
|
||||
)
|
||||
)
|
||||
self.model_store.new_snapshot(tag, snapshot_hash, files)
|
||||
self.model_store.new_snapshot(tag, snapshot_hash, files, verify=getattr(args, "verify", True))
|
||||
return
|
||||
|
||||
@@ -167,6 +167,22 @@ load setup_suite
|
||||
run_ramalama rm oci://quay.io/mmortari/gguf-py-example:v1
|
||||
}
|
||||
|
||||
@test "ramalama pull little-endian" {
|
||||
if ! is_bigendian; then
|
||||
skip "Testing pulls of opposite-endian models"
|
||||
fi
|
||||
run_ramalama 1 pull --verify=on tiny
|
||||
is "$output" ".*Endian mismatch of host (BIG) and model (LITTLE).*" "detected little-endian model"
|
||||
}
|
||||
|
||||
@test "ramalama pull big-endian" {
|
||||
if is_bigendian; then
|
||||
skip "Testing pulls of opposite-endian models"
|
||||
fi
|
||||
run_ramalama 1 pull --verify=on granite-be-3.0:1b
|
||||
is "$output" ".*Endian mismatch of host (LITTLE) and model (BIG).*" "detected big-endian model"
|
||||
}
|
||||
|
||||
@test "ramalama URL" {
|
||||
model=$RAMALAMA_TMPDIR/mymodel.gguf
|
||||
touch $model
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
[ramalama]
|
||||
|
||||
pull="missing"
|
||||
|
||||
verify=false
|
||||
|
||||
@@ -150,3 +150,28 @@ def test_main_doesnt_crash_on_exc(monkeypatch, exc_type):
|
||||
with pytest.raises(SystemExit):
|
||||
with mock.patch("ramalama.cli.inspect_cli", side_effect=exc_type):
|
||||
main()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"option, value",
|
||||
[
|
||||
(None, True),
|
||||
("yes", True),
|
||||
("on", True),
|
||||
("1", True),
|
||||
("no", False),
|
||||
("off", False),
|
||||
("0", False),
|
||||
],
|
||||
)
|
||||
def test_pull_verify(monkeypatch, option, value):
|
||||
from ramalama.cli import init_cli
|
||||
|
||||
argv = ["ramalama", "pull"]
|
||||
if option:
|
||||
argv.append(f"--verify={option}")
|
||||
argv.append("model")
|
||||
monkeypatch.setattr(sys, "argv", argv)
|
||||
parser, args = init_cli()
|
||||
assert hasattr(args, "verify")
|
||||
assert args.verify == value
|
||||
|
||||
@@ -30,6 +30,7 @@ def test_correct_config_defaults(monkeypatch):
|
||||
assert cfg.temp == "0.8"
|
||||
assert cfg.transport == "ollama"
|
||||
assert cfg.ocr is False
|
||||
assert cfg.verify is True
|
||||
|
||||
|
||||
def test_config_defaults_not_set(monkeypatch):
|
||||
@@ -55,6 +56,7 @@ def test_config_defaults_not_set(monkeypatch):
|
||||
assert cfg.is_set("temp") is False
|
||||
assert cfg.is_set("transport") is False
|
||||
assert cfg.is_set("ocr") is False
|
||||
assert cfg.is_set("verify") is False
|
||||
|
||||
|
||||
def test_file_config_overrides_defaults():
|
||||
@@ -62,6 +64,7 @@ def test_file_config_overrides_defaults():
|
||||
"image": "custom/image:latest",
|
||||
"threads": 8,
|
||||
"container": False,
|
||||
"verify": False,
|
||||
}
|
||||
|
||||
with patch("ramalama.config.load_file_config", return_value=mock_file_config):
|
||||
@@ -70,10 +73,12 @@ def test_file_config_overrides_defaults():
|
||||
assert cfg.image == "custom/image:latest"
|
||||
assert cfg.threads == 8
|
||||
assert cfg.container is False
|
||||
assert cfg.verify is False
|
||||
|
||||
assert cfg.is_set("image") is True
|
||||
assert cfg.is_set("threads") is True
|
||||
assert cfg.is_set("container") is True
|
||||
assert cfg.is_set("verify") is True
|
||||
|
||||
|
||||
def test_env_overrides_file_and_default():
|
||||
@@ -183,6 +188,7 @@ class TestLoadEnvConfig:
|
||||
"RAMALAMA_THREADS": "8",
|
||||
"RAMALAMA_CONTAINER": "true",
|
||||
"RAMALAMA_HOST": "127.0.0.1",
|
||||
"RAMALAMA_VERIFY": "false",
|
||||
}
|
||||
|
||||
result = load_env_config(env)
|
||||
@@ -192,6 +198,7 @@ class TestLoadEnvConfig:
|
||||
"threads": 8,
|
||||
"container": True,
|
||||
"host": "127.0.0.1",
|
||||
"verify": False,
|
||||
}
|
||||
assert result == expected
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import tempfile
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -43,7 +43,7 @@ class OllamaRepositoryMock(OllamaRepository):
|
||||
return [LocalSnapshotFile("dummy content", "dummy", SnapshotFileType.Other)]
|
||||
|
||||
|
||||
def test_ollama_model_pull(ollama_model, args):
|
||||
def test_ollama_model_pull(ollama_model):
|
||||
args.quiet = True
|
||||
with patch("ramalama.ollama.OllamaRepository", return_value=OllamaRepositoryMock("dummy-model")):
|
||||
ollama_model.pull(args)
|
||||
ollama_model.pull(Mock(verify=True))
|
||||
|
||||
Reference in New Issue
Block a user