1
0
mirror of https://github.com/containers/ramalama.git synced 2026-02-05 06:46:39 +01:00

Merge pull request #2359 from olliewalsh/flash_attn

This commit is contained in:
Mike Bonnet
2026-01-29 09:37:30 -08:00
committed by GitHub
7 changed files with 13 additions and 30 deletions

View File

@@ -55,10 +55,6 @@ commands:
- name: "--no-webui"
description: "Disable the Web UI"
if: "{{ args.webui == 'off' }}"
- name: "--flash-attn"
description: "Set Flash Attention use"
value: "on"
if: "{{ host.uses_nvidia or host.uses_metal }}"
- name: "-ngl"
description: "Number of layers to offload to the GPU if available"
value: "{{ 999 if args.ngl < 0 else args.ngl }}"

View File

@@ -2,12 +2,13 @@ import argparse
import os
from typing import Optional
from ramalama.common import check_metal, check_nvidia, get_accel_env_vars
from ramalama.common import check_metal, check_nvidia
from ramalama.console import should_colorize
from ramalama.transports.transport_factory import CLASS_MODEL_TYPES, New
class RamalamaArgsContext:
def __init__(self) -> None:
self.cache_reuse: Optional[int] = None
self.container: Optional[bool] = None
@@ -51,6 +52,7 @@ class RamalamaArgsContext:
class RamalamaRagGenArgsContext:
def __init__(self) -> None:
self.debug: bool | None = None
self.format: str | None = None
@@ -72,6 +74,7 @@ class RamalamaRagGenArgsContext:
class RamalamaRagArgsContext:
def __init__(self) -> None:
self.debug: bool | None = None
self.port: str | None = None
@@ -89,6 +92,7 @@ class RamalamaRagArgsContext:
class RamalamaModelContext:
def __init__(self, model: CLASS_MODEL_TYPES, is_container: bool, should_generate: bool, dry_run: bool):
self.model = model
self.is_container = is_container
@@ -124,6 +128,7 @@ class RamalamaModelContext:
class RamalamaHostContext:
def __init__(
self, is_container: bool, uses_nvidia: bool, uses_metal: bool, should_colorize: bool, rpc_nodes: Optional[str]
):
@@ -135,6 +140,7 @@ class RamalamaHostContext:
class RamalamaCommandContext:
def __init__(
self,
args: RamalamaArgsContext | RamalamaRagGenArgsContext | RamalamaRagArgsContext,
@@ -164,12 +170,9 @@ class RamalamaCommandContext:
else:
model = None
skip_gpu_probe = should_generate or bool(get_accel_env_vars())
uses_nvidia = True if skip_gpu_probe else (check_nvidia() is None)
host = RamalamaHostContext(
is_container,
uses_nvidia,
check_nvidia() is not None,
check_metal(argparse.Namespace(**{"container": is_container})),
should_colorize(),
os.getenv("RAMALAMA_LLAMACPP_RPC_NODES", None),

View File

@@ -16,7 +16,6 @@ from test.conftest import (
skip_if_docker,
skip_if_gh_actions_darwin,
skip_if_no_container,
skip_if_not_darwin,
skip_if_ppc64le,
skip_if_s390x,
)
@@ -182,10 +181,6 @@ def test_basic_dry_run():
[], r".*--cache-reuse 256", None, None, True,
id="check --cache-reuse default value", marks=skip_if_container
),
pytest.param(
[], r".*--flash-attn", None, None, True,
id="check --flash-attn", marks=[skip_if_container, skip_if_not_darwin]
),
pytest.param(
["--host", "127.0.0.1"],
r".*--host 127.0.0.1", None, None, True,

View File

@@ -73,9 +73,6 @@ verify_begin=".*run --rm"
run_ramalama -q --dryrun serve ${model}
assert "$output" =~ ".*--host 0.0.0.0" "Outside container sets host to 0.0.0.0"
is "$output" ".*--cache-reuse 256" "should use cache"
if is_darwin; then
is "$output" ".*--flash-attn on" "use flash-attn on Darwin metal"
fi
run_ramalama -q --dryrun serve --seed abcd --host 127.0.0.1 ${model}
assert "$output" =~ ".*--host 127.0.0.1" "Outside container overrides host to 127.0.0.1"

View File

@@ -53,10 +53,6 @@ commands:
- name: "--no-webui"
description: "Disable the Web UI"
if: "{{ args.webui == 'off' }}"
- name: "--flash-attn"
description: "Set Flash Attention use"
value: "on"
if: "{{ host.uses_nvidia or host.uses_metal }}"
- name: "-ngl"
description: "Number of layers to offload to the GPU if available"
value: "{{ 999 if args.ngl < 0 else args.ngl }}"

View File

@@ -55,10 +55,6 @@ commands:
- name: "--no-webui"
description: "Disable the Web UI"
if: "{{ args.webui == 'off' }}"
- name: "--flash-attn"
description: "Set Flash Attention use"
value: "on"
if: "{{ host.uses_nvidia or host.uses_metal }}"
- name: "-ngl"
description: "Number of layers to offload to the GPU if available"
value: "{{ 999 if args.ngl < 0 else args.ngl }}"

View File

@@ -83,23 +83,23 @@ class FactoryInput:
[
(
FactoryInput(),
"llama-server --host 0.0.0.0 --port 1337 --log-file /var/tmp/ramalama.log --model /path/to/model --chat-template-file /path/to/chat-template --jinja --no-warmup --reasoning-budget 0 --alias library/smollm --ctx-size 512 --temp 11 --cache-reuse 1024 -v --flash-attn on -ngl 44 --model-draft /path/to/draft-model -ngld 44 --threads 8 --seed 12345 --log-colors on --another-arg 44 --more-args", # noqa: E501
"llama-server --host 0.0.0.0 --port 1337 --log-file /var/tmp/ramalama.log --model /path/to/model --chat-template-file /path/to/chat-template --jinja --no-warmup --reasoning-budget 0 --alias library/smollm --ctx-size 512 --temp 11 --cache-reuse 1024 -v -ngl 44 --model-draft /path/to/draft-model -ngld 44 --threads 8 --seed 12345 --log-colors on --another-arg 44 --more-args", # noqa: E501
),
(
FactoryInput(has_mmproj=True),
"llama-server --host 0.0.0.0 --port 1337 --log-file /var/tmp/ramalama.log --model /path/to/model --mmproj /path/to/mmproj --no-warmup --reasoning-budget 0 --alias library/smollm --ctx-size 512 --temp 11 --cache-reuse 1024 -v --flash-attn on -ngl 44 --model-draft /path/to/draft-model -ngld 44 --threads 8 --seed 12345 --log-colors on --another-arg 44 --more-args", # noqa: E501
"llama-server --host 0.0.0.0 --port 1337 --log-file /var/tmp/ramalama.log --model /path/to/model --mmproj /path/to/mmproj --no-warmup --reasoning-budget 0 --alias library/smollm --ctx-size 512 --temp 11 --cache-reuse 1024 -v -ngl 44 --model-draft /path/to/draft-model -ngld 44 --threads 8 --seed 12345 --log-colors on --another-arg 44 --more-args", # noqa: E501
),
(
FactoryInput(has_chat_template=False),
"llama-server --host 0.0.0.0 --port 1337 --log-file /var/tmp/ramalama.log --model /path/to/model --jinja --no-warmup --reasoning-budget 0 --alias library/smollm --ctx-size 512 --temp 11 --cache-reuse 1024 -v --flash-attn on -ngl 44 --model-draft /path/to/draft-model -ngld 44 --threads 8 --seed 12345 --log-colors on --another-arg 44 --more-args", # noqa: E501
"llama-server --host 0.0.0.0 --port 1337 --log-file /var/tmp/ramalama.log --model /path/to/model --jinja --no-warmup --reasoning-budget 0 --alias library/smollm --ctx-size 512 --temp 11 --cache-reuse 1024 -v -ngl 44 --model-draft /path/to/draft-model -ngld 44 --threads 8 --seed 12345 --log-colors on --another-arg 44 --more-args", # noqa: E501
),
(
FactoryInput(cli_args=CLIArgs(runtime_args="")),
"llama-server --host 0.0.0.0 --port 1337 --log-file /var/tmp/ramalama.log --model /path/to/model --chat-template-file /path/to/chat-template --jinja --no-warmup --reasoning-budget 0 --alias library/smollm --ctx-size 512 --temp 11 --cache-reuse 1024 -v --flash-attn on -ngl 44 --model-draft /path/to/draft-model -ngld 44 --threads 8 --seed 12345 --log-colors on", # noqa: E501
"llama-server --host 0.0.0.0 --port 1337 --log-file /var/tmp/ramalama.log --model /path/to/model --chat-template-file /path/to/chat-template --jinja --no-warmup --reasoning-budget 0 --alias library/smollm --ctx-size 512 --temp 11 --cache-reuse 1024 -v -ngl 44 --model-draft /path/to/draft-model -ngld 44 --threads 8 --seed 12345 --log-colors on", # noqa: E501
),
(
FactoryInput(cli_args=CLIArgs(max_tokens=99, runtime_args="")),
"llama-server --host 0.0.0.0 --port 1337 --log-file /var/tmp/ramalama.log --model /path/to/model --chat-template-file /path/to/chat-template --jinja --no-warmup --reasoning-budget 0 --alias library/smollm --ctx-size 512 --temp 11 --cache-reuse 1024 -v --flash-attn on -ngl 44 --model-draft /path/to/draft-model -ngld 44 --threads 8 --seed 12345 --log-colors on -n 99", # noqa: E501
"llama-server --host 0.0.0.0 --port 1337 --log-file /var/tmp/ramalama.log --model /path/to/model --chat-template-file /path/to/chat-template --jinja --no-warmup --reasoning-budget 0 --alias library/smollm --ctx-size 512 --temp 11 --cache-reuse 1024 -v -ngl 44 --model-draft /path/to/draft-model -ngld 44 --threads 8 --seed 12345 --log-colors on -n 99", # noqa: E501
),
],
)