Fix quadlet generation for multi-part models
Add support for generating multiple Mount= declarations in quadlet files for multi-part models (e.g., models split into multiple GGUF files).

Changes:
- Add _get_all_model_part_paths() method to Transport class to retrieve all parts of a multi-part model from the ref file
- Update Quadlet class to accept and store model_parts list
- Modify _gen_model_volume() to generate Mount= entries for each part
- Update generate_container_config() to pass model parts to quadlet
- Add test case for multi-part model quadlet generation

The fix ensures that when generating quadlet files for models like gpt-oss-120b that are split across multiple files, all parts are properly mounted into the container with correct src/dest paths.

Fixes #2017

Signed-off-by: Daniel J Walsh <dwalsh@redhat.com>
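For orientation before the diff itself, here is a minimal sketch of the idea (the names below are illustrative, not ramalama's API): given a list of (src, dest) tuples for the model parts, one quadlet Mount= declaration is emitted per part.

    # Minimal sketch of the multi-part Mount= generation this commit adds.
    # `mount_lines` and the sample part names are illustrative stand-ins.
    def mount_lines(parts: list[tuple[str, str]]) -> list[str]:
        return [f"Mount=type=bind,src={src},target={dest},ro,Z" for src, dest in parts]

    parts = [
        ("sha256-part1", "/mnt/models/model-00001-of-00002.gguf"),
        ("sha256-part2", "/mnt/models/model-00002-of-00002.gguf"),
    ]
    print("\n".join(mount_lines(parts)))  # one Mount= line per part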
@@ -16,6 +16,7 @@ class Quadlet:
         args,
         exec_args,
         artifact: bool,
+        model_parts: Optional[list[Tuple[str, str]]] = None,
     ):
         self.src_model_path, self.dest_model_path = model_paths
         self.src_chat_template_path, self.dest_chat_template_path = (
@@ -23,6 +24,9 @@
         )
         self.src_mmproj_path, self.dest_mmproj_path = mmproj_path if mmproj_path is not None else ("", "")
 
+        # Store all model parts for multi-part models
+        self.model_parts = model_parts if model_parts is not None else [(self.src_model_path, self.dest_model_path)]
+
         if self.src_model_path.startswith("oci://"):
             self.src_model_path = self.src_model_path.removeprefix("oci://")
             self.ai_image = self.src_model_path
@@ -133,12 +137,16 @@ class Quadlet:
     def _gen_model_volume(self, quadlet_file: UnitFile):
         files: list[UnitFile] = []
 
-        if os.path.exists(self.src_model_path):
-            quadlet_file.add(
-                "Container", "Mount", f"type=bind,src={self.src_model_path},target={self.dest_model_path},ro,Z"
-            )
+        # Check if any model part exists as a file (non-OCI model from store)
+        local_model_parts = [part for part in self.model_parts if os.path.exists(part[0])]
+
+        if local_model_parts:
+            # Generate Mount= entries for each model part
+            for src_path, dest_path in local_model_parts:
+                quadlet_file.add("Container", "Mount", f"type=bind,src={src_path},target={dest_path},ro,Z")
             return files
 
         # OCI model handling
         volume_file_name = f"{self.name}.volume"
         print(f"Generating quadlet file: {volume_file_name} ")
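Note how this interacts with the constructor change above: when no model_parts list is supplied, Quadlet falls back to a single-element list, so single-file models keep producing exactly one Mount= line. An illustrative mirror of that fallback (not the project's code):

    # With model_parts=None the mount loop degenerates to the old
    # single-Mount behavior, preserving backward compatibility.
    def effective_parts(model_parts, src, dest):
        return model_parts if model_parts is not None else [(src, dest)]

    assert effective_parts(None, "/store/blob", "/mnt/models/model.file") == [
        ("/store/blob", "/mnt/models/model.file")
    ]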
@@ -197,6 +197,55 @@ class Transport(TransportBase):
             self._model_store = ModelStore(GlobalModelStore(self._model_store_path), name, self.model_type, orga)
         return self._model_store
 
+    def _get_all_model_part_paths(
+        self, use_container: bool, should_generate: bool, dry_run: bool
+    ) -> list[tuple[str, str]]:
+        """
+        Returns a list of (src_path, dest_path) tuples for all parts of a model.
+        For single-file models, returns a list with one tuple.
+        For multi-part models, returns a tuple for each part.
+        """
+        if dry_run:
+            return [("/path/to/model", f"{MNT_DIR}/model.file")]
+
+        if self.model_type == 'oci':
+            # OCI models don't use this path for multi-part handling
+            entry_path_src = self._get_entry_model_path(False, False, False)
+            entry_path_dest = self._get_entry_model_path(True, True, False)
+            return [(entry_path_src, entry_path_dest)]
+
+        ref_file = self.model_store.get_ref_file(self.model_tag)
+        if ref_file is None:
+            raise NoRefFileFound(self.model)
+
+        gguf_files = ref_file.model_files
+        safetensor_files = ref_file.safetensor_model_files
+        if safetensor_files:
+            # Safetensor models use directory mounts, not individual files
+            src_path = self.model_store.get_snapshot_directory_from_tag(self.model_tag)
+            if use_container or should_generate:
+                dest_path = MNT_DIR
+            else:
+                dest_path = src_path
+            return [(src_path, dest_path)]
+        elif not gguf_files:
+            raise NoGGUFModelFileFound()
+
+        model_parts = []
+        for model_file in gguf_files:
+            if use_container or should_generate:
+                dest_path = f"{MNT_DIR}/{model_file.name}"
+            else:
+                dest_path = self.model_store.get_blob_file_path(model_file.hash)
+            src_path = self.model_store.get_blob_file_path(model_file.hash)
+            model_parts.append((src_path, dest_path))
+
+        # Sort multi-part models by filename to ensure correct order
+        if len(model_parts) > 1 and any("-00001-of-" in name for _, name in model_parts):
+            model_parts.sort(key=lambda x: x[1])
+
+        return model_parts
+
     def _get_entry_model_path(self, use_container: bool, should_generate: bool, dry_run: bool) -> str:
         """
         Returns the path to the model blob on the host if use_container and should_generate are both False.
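The filename sort above is sufficient because split-GGUF part numbers are zero-padded, so plain lexicographic ordering matches numeric part order. A quick self-contained check (sample names, not project data):

    # Zero-padded "-NNNNN-of-NNNNN" suffixes sort lexicographically in
    # numeric part order, which is what model_parts.sort() relies on.
    names = [
        "/mnt/models/m-00003-of-00003.gguf",
        "/mnt/models/m-00001-of-00003.gguf",
        "/mnt/models/m-00002-of-00003.gguf",
    ]
    assert sorted(names) == [
        "/mnt/models/m-00001-of-00003.gguf",
        "/mnt/models/m-00002-of-00003.gguf",
        "/mnt/models/m-00003-of-00003.gguf",
    ]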
@@ -599,6 +648,9 @@ class Transport(TransportBase):
         chat_template_dest_path = self._get_chat_template_path(True, True, args.dryrun)
         mmproj_dest_path = self._get_mmproj_path(True, True, args.dryrun)
 
+        # Get all model parts (for multi-part models)
+        model_parts = self._get_all_model_part_paths(False, True, args.dryrun)
+
         if args.generate.gen_type == "quadlet":
             self.quadlet(
                 (model_src_path, model_dest_path),
@@ -607,6 +659,7 @@ class Transport(TransportBase):
                 args,
                 exec_args,
                 args.generate.output_dir,
+                model_parts,
             )
         elif args.generate.gen_type == "kube":
             self.kube(
@@ -625,6 +678,7 @@ class Transport(TransportBase):
                 args,
                 exec_args,
                 args.generate.output_dir,
+                model_parts,
             )
         elif args.generate.gen_type == "compose":
             self.compose(
@@ -664,18 +718,22 @@ class Transport(TransportBase):
             self._cleanup_server_process(args.server_process)
             raise e
 
-    def quadlet(self, model_paths, chat_template_paths, mmproj_paths, args, exec_args, output_dir):
+    def quadlet(self, model_paths, chat_template_paths, mmproj_paths, args, exec_args, output_dir, model_parts=None):
         quadlet = Quadlet(
-            self.model_name, model_paths, chat_template_paths, mmproj_paths, args, exec_args, self.artifact
+            self.model_name, model_paths, chat_template_paths, mmproj_paths, args, exec_args, self.artifact, model_parts
         )
         for generated_file in quadlet.generate():
             generated_file.write(output_dir)
 
-    def quadlet_kube(self, model_paths, chat_template_paths, mmproj_paths, args, exec_args, output_dir):
+    def quadlet_kube(
+        self, model_paths, chat_template_paths, mmproj_paths, args, exec_args, output_dir, model_parts=None
+    ):
         kube = Kube(self.model_name, model_paths, chat_template_paths, mmproj_paths, args, exec_args, self.artifact)
         kube.generate().write(output_dir)
 
-        quadlet = Quadlet(kube.name, model_paths, chat_template_paths, mmproj_paths, args, exec_args, self.artifact)
+        quadlet = Quadlet(
+            kube.name, model_paths, chat_template_paths, mmproj_paths, args, exec_args, self.artifact, model_parts
+        )
         quadlet.kube().write(output_dir)
 
     def kube(self, model_paths, chat_template_paths, mmproj_paths, args, exec_args, output_dir):
test/unit/data/test_quadlet/multipart/gpt-oss-120b.container (new file, 23 lines)
@@ -0,0 +1,23 @@
+[Unit]
+Description=RamaLama gpt-oss-120b AI Model Service
+After=local-fs.target
+
+[Container]
+AddDevice=-/dev/accel
+AddDevice=-/dev/dri
+AddDevice=-/dev/kfd
+AddDevice=nvidia.com/gpu=all
+Image=testimage
+RunInit=true
+Environment=HOME=/tmp
+Exec=
+SecurityLabelDisable=true
+DropCapability=all
+NoNewPrivileges=true
+Mount=type=bind,src=sha256-e2865eb6c1df7b2ffbebf305cd5d9074d5ccc0fe3b862f98d343a46dad1606f9,target=/mnt/models/gpt-oss-120b-mxfp4-00001-of-00003.gguf,ro,Z
+Mount=type=bind,src=sha256-81856b5b996da9c9fd68397d49671264ead380a8355b3c83284eae5e21e998ed,target=/mnt/models/gpt-oss-120b-mxfp4-00002-of-00003.gguf,ro,Z
+Mount=type=bind,src=sha256-38b087fffe4b5ba5d62fa7761ed7278a07fef7a6145b9744b11205b851021dce,target=/mnt/models/gpt-oss-120b-mxfp4-00003-of-00003.gguf,ro,Z
+
+[Install]
+WantedBy=multi-user.target default.target
+
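One way to sanity-check this fixture is to confirm the three Mount= lines appear in part order; the snippet below is a hypothetical check, not the suite's actual assertion code:

    # Hypothetical verification over the fixture above.
    from pathlib import Path

    text = Path("test/unit/data/test_quadlet/multipart/gpt-oss-120b.container").read_text()
    mounts = [line for line in text.splitlines() if line.startswith("Mount=")]
    targets = [m.split("target=")[1].split(",")[0] for m in mounts]
    assert len(mounts) == 3 and targets == sorted(targets)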
@@ -48,6 +48,7 @@ class Input:
         exec_args: list = [],
         accel_type: str = "cuda",
         artifact: str = "",
+        model_parts: Optional[list] = None,
     ):
         self.model_name = model_name
         self.model_src_blob = model_src_blob
@@ -64,6 +65,7 @@ class Input:
         self.exec_args = exec_args
         self.accel_type = accel_type
         self.artifact = artifact
+        self.model_parts = model_parts
 
 
 DATA_PATH = Path(__file__).parent / "data" / "test_quadlet"
@@ -203,6 +205,30 @@ DATA_PATH = Path(__file__).parent / "data" / "test_quadlet"
             ),
             DATA_PATH / "oci_rag",
         ),
+        (
+            Input(
+                model_name="gpt-oss-120b",
+                model_src_blob="sha256-e2865eb6c1df7b2ffbebf305cd5d9074d5ccc0fe3b862f98d343a46dad1606f9",
+                model_dest_name="/mnt/models/gpt-oss-120b-mxfp4-00001-of-00003.gguf",
+                image="testimage",
+                model_file_exists=True,
+                model_parts=[
+                    (
+                        "sha256-e2865eb6c1df7b2ffbebf305cd5d9074d5ccc0fe3b862f98d343a46dad1606f9",
+                        "/mnt/models/gpt-oss-120b-mxfp4-00001-of-00003.gguf",
+                    ),
+                    (
+                        "sha256-81856b5b996da9c9fd68397d49671264ead380a8355b3c83284eae5e21e998ed",
+                        "/mnt/models/gpt-oss-120b-mxfp4-00002-of-00003.gguf",
+                    ),
+                    (
+                        "sha256-38b087fffe4b5ba5d62fa7761ed7278a07fef7a6145b9744b11205b851021dce",
+                        "/mnt/models/gpt-oss-120b-mxfp4-00003-of-00003.gguf",
+                    ),
+                ],
+            ),
+            DATA_PATH / "multipart",
+        ),
     ],
 )
 def test_quadlet_generate(input: Input, expected_files_path: Path, monkeypatch):
@@ -217,6 +243,11 @@ def test_quadlet_generate(input: Input, expected_files_path: Path, monkeypatch):
         input.mmproj_src_blob: input.mmproj_file_exists,
     }
 
+    # Add existence checks for all model parts
+    if input.model_parts:
+        for src, _ in input.model_parts:
+            existence[src] = input.model_file_exists
+
     monkeypatch.setattr("os.path.exists", lambda path: existence.get(path, False))
     monkeypatch.setattr(Quadlet, "_gen_env", lambda self, quadlet_file: None)
     monkeypatch.setattr("ramalama.quadlet.get_accel", lambda: input.accel_type)
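The monkeypatched os.path.exists is what steers _gen_model_volume into the local Mount= branch; a condensed standalone illustration of the pattern (not the suite's code):

    # os.path.exists becomes a dict lookup, so only the listed blobs
    # count as present and the per-part Mount= branch is exercised.
    existence = {"sha256-part1": True, "sha256-part2": True}

    def fake_exists(path: str) -> bool:
        return existence.get(path, False)

    assert fake_exists("sha256-part1") and not fake_exists("/elsewhere")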
@@ -229,6 +260,7 @@ def test_quadlet_generate(input: Input, expected_files_path: Path, monkeypatch):
         input.args,
         input.exec_args,
         input.artifact,
+        input.model_parts,
     ).generate():
         assert file.filename in expected_files
 