Mirror of https://github.com/containers/ramalama.git (synced 2026-02-05 06:46:39 +01:00)

Merge pull request #2359 from olliewalsh/flash_attn
@@ -55,10 +55,6 @@ commands:
   - name: "--no-webui"
     description: "Disable the Web UI"
     if: "{{ args.webui == 'off' }}"
-  - name: "--flash-attn"
-    description: "Set Flash Attention use"
-    value: "on"
-    if: "{{ host.uses_nvidia or host.uses_metal }}"
   - name: "-ngl"
     description: "Number of layers to offload to the GPU if available"
     value: "{{ 999 if args.ngl < 0 else args.ngl }}"
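The `if:` and `value:` fields in this spec are Jinja-style expressions evaluated against an `args`/`host` context. The following is a minimal, hypothetical sketch of how such entries could be rendered into an argv fragment; it is not ramalama's actual renderer, and the `SimpleNamespace` context values are made up for illustration.

```python
# Hypothetical sketch: render templated flag entries like the ones above.
from types import SimpleNamespace

from jinja2 import Template

# Stand-in context mirroring the "args"/"host" names used in the spec.
ctx = {
    "args": SimpleNamespace(webui="off", ngl=-1),
    "host": SimpleNamespace(uses_nvidia=False, uses_metal=True),
}

spec = [
    {"name": "--no-webui", "if": "{{ args.webui == 'off' }}"},
    {"name": "-ngl", "value": "{{ 999 if args.ngl < 0 else args.ngl }}"},
]


def render_entry(entry: dict, context: dict) -> list[str]:
    """Render one spec entry to an argv fragment; skip it if its 'if' is false."""
    cond = entry.get("if")
    if cond is not None and Template(cond).render(**context) != "True":
        return []
    argv = [entry["name"]]
    if "value" in entry:
        argv.append(Template(entry["value"]).render(**context))
    return argv


argv = [token for entry in spec for token in render_entry(entry, ctx)]
print(argv)  # ['--no-webui', '-ngl', '999']
```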
@@ -2,12 +2,13 @@ import argparse
 import os
 from typing import Optional
 
-from ramalama.common import check_metal, check_nvidia, get_accel_env_vars
+from ramalama.common import check_metal, check_nvidia
 from ramalama.console import should_colorize
 from ramalama.transports.transport_factory import CLASS_MODEL_TYPES, New
 
 
 class RamalamaArgsContext:
 
     def __init__(self) -> None:
         self.cache_reuse: Optional[int] = None
         self.container: Optional[bool] = None
@@ -51,6 +52,7 @@ class RamalamaArgsContext:
 
 
 class RamalamaRagGenArgsContext:
 
     def __init__(self) -> None:
         self.debug: bool | None = None
         self.format: str | None = None

@@ -72,6 +74,7 @@ class RamalamaRagGenArgsContext:
 
 
 class RamalamaRagArgsContext:
 
     def __init__(self) -> None:
         self.debug: bool | None = None
         self.port: str | None = None

@@ -89,6 +92,7 @@ class RamalamaRagArgsContext:
 
 
 class RamalamaModelContext:
 
     def __init__(self, model: CLASS_MODEL_TYPES, is_container: bool, should_generate: bool, dry_run: bool):
         self.model = model
         self.is_container = is_container

@@ -124,6 +128,7 @@ class RamalamaModelContext:
 
 
 class RamalamaHostContext:
 
     def __init__(
         self, is_container: bool, uses_nvidia: bool, uses_metal: bool, should_colorize: bool, rpc_nodes: Optional[str]
     ):

@@ -135,6 +140,7 @@ class RamalamaHostContext:
 
 
 class RamalamaCommandContext:
 
     def __init__(
         self,
         args: RamalamaArgsContext | RamalamaRagGenArgsContext | RamalamaRagArgsContext,
@@ -164,12 +170,9 @@ class RamalamaCommandContext:
         else:
             model = None
 
-        skip_gpu_probe = should_generate or bool(get_accel_env_vars())
-        uses_nvidia = True if skip_gpu_probe else (check_nvidia() is None)
-
         host = RamalamaHostContext(
             is_container,
-            uses_nvidia,
+            check_nvidia() is not None,
             check_metal(argparse.Namespace(**{"container": is_container})),
             should_colorize(),
             os.getenv("RAMALAMA_LLAMACPP_RPC_NODES", None),
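For readers skimming the hunk above: the derived `uses_nvidia` value is replaced by the direct probe result passed to `RamalamaHostContext`. A standalone restatement of the before/after logic, using illustrative helper functions that are not part of ramalama's API:

```python
# Illustrative restatement of the removed vs. added logic above; these helpers
# are not ramalama code, and "nvidia_probe" stands in for check_nvidia()'s
# return value.
from typing import Any, Optional


def uses_nvidia_before(should_generate: bool, accel_env_vars: dict, nvidia_probe: Optional[Any]) -> bool:
    # Removed lines: skip the GPU probe when generating or when accelerator
    # env vars are set, otherwise report True when the probe returned None.
    skip_gpu_probe = should_generate or bool(accel_env_vars)
    return True if skip_gpu_probe else (nvidia_probe is None)


def uses_nvidia_after(nvidia_probe: Optional[Any]) -> bool:
    # Added line: the host context now simply records whether the probe
    # found an NVIDIA accelerator.
    return nvidia_probe is not None
```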
@@ -16,7 +16,6 @@ from test.conftest import (
|
||||
skip_if_docker,
|
||||
skip_if_gh_actions_darwin,
|
||||
skip_if_no_container,
|
||||
skip_if_not_darwin,
|
||||
skip_if_ppc64le,
|
||||
skip_if_s390x,
|
||||
)
|
||||
@@ -182,10 +181,6 @@ def test_basic_dry_run():
|
||||
[], r".*--cache-reuse 256", None, None, True,
|
||||
id="check --cache-reuse default value", marks=skip_if_container
|
||||
),
|
||||
pytest.param(
|
||||
[], r".*--flash-attn", None, None, True,
|
||||
id="check --flash-attn", marks=[skip_if_container, skip_if_not_darwin]
|
||||
),
|
||||
pytest.param(
|
||||
["--host", "127.0.0.1"],
|
||||
r".*--host 127.0.0.1", None, None, True,
|
||||
|
||||
@@ -73,9 +73,6 @@ verify_begin=".*run --rm"
|
||||
run_ramalama -q --dryrun serve ${model}
|
||||
assert "$output" =~ ".*--host 0.0.0.0" "Outside container sets host to 0.0.0.0"
|
||||
is "$output" ".*--cache-reuse 256" "should use cache"
|
||||
if is_darwin; then
|
||||
is "$output" ".*--flash-attn on" "use flash-attn on Darwin metal"
|
||||
fi
|
||||
|
||||
run_ramalama -q --dryrun serve --seed abcd --host 127.0.0.1 ${model}
|
||||
assert "$output" =~ ".*--host 127.0.0.1" "Outside container overrides host to 127.0.0.1"
|
||||
|
||||
@@ -53,10 +53,6 @@ commands:
   - name: "--no-webui"
     description: "Disable the Web UI"
     if: "{{ args.webui == 'off' }}"
-  - name: "--flash-attn"
-    description: "Set Flash Attention use"
-    value: "on"
-    if: "{{ host.uses_nvidia or host.uses_metal }}"
   - name: "-ngl"
     description: "Number of layers to offload to the GPU if available"
     value: "{{ 999 if args.ngl < 0 else args.ngl }}"
@@ -55,10 +55,6 @@ commands:
   - name: "--no-webui"
     description: "Disable the Web UI"
     if: "{{ args.webui == 'off' }}"
-  - name: "--flash-attn"
-    description: "Set Flash Attention use"
-    value: "on"
-    if: "{{ host.uses_nvidia or host.uses_metal }}"
   - name: "-ngl"
     description: "Number of layers to offload to the GPU if available"
     value: "{{ 999 if args.ngl < 0 else args.ngl }}"
@@ -83,23 +83,23 @@ class FactoryInput:
     [
         (
             FactoryInput(),
-            "llama-server --host 0.0.0.0 --port 1337 --log-file /var/tmp/ramalama.log --model /path/to/model --chat-template-file /path/to/chat-template --jinja --no-warmup --reasoning-budget 0 --alias library/smollm --ctx-size 512 --temp 11 --cache-reuse 1024 -v --flash-attn on -ngl 44 --model-draft /path/to/draft-model -ngld 44 --threads 8 --seed 12345 --log-colors on --another-arg 44 --more-args", # noqa: E501
+            "llama-server --host 0.0.0.0 --port 1337 --log-file /var/tmp/ramalama.log --model /path/to/model --chat-template-file /path/to/chat-template --jinja --no-warmup --reasoning-budget 0 --alias library/smollm --ctx-size 512 --temp 11 --cache-reuse 1024 -v -ngl 44 --model-draft /path/to/draft-model -ngld 44 --threads 8 --seed 12345 --log-colors on --another-arg 44 --more-args", # noqa: E501
         ),
         (
             FactoryInput(has_mmproj=True),
-            "llama-server --host 0.0.0.0 --port 1337 --log-file /var/tmp/ramalama.log --model /path/to/model --mmproj /path/to/mmproj --no-warmup --reasoning-budget 0 --alias library/smollm --ctx-size 512 --temp 11 --cache-reuse 1024 -v --flash-attn on -ngl 44 --model-draft /path/to/draft-model -ngld 44 --threads 8 --seed 12345 --log-colors on --another-arg 44 --more-args", # noqa: E501
+            "llama-server --host 0.0.0.0 --port 1337 --log-file /var/tmp/ramalama.log --model /path/to/model --mmproj /path/to/mmproj --no-warmup --reasoning-budget 0 --alias library/smollm --ctx-size 512 --temp 11 --cache-reuse 1024 -v -ngl 44 --model-draft /path/to/draft-model -ngld 44 --threads 8 --seed 12345 --log-colors on --another-arg 44 --more-args", # noqa: E501
         ),
         (
             FactoryInput(has_chat_template=False),
-            "llama-server --host 0.0.0.0 --port 1337 --log-file /var/tmp/ramalama.log --model /path/to/model --jinja --no-warmup --reasoning-budget 0 --alias library/smollm --ctx-size 512 --temp 11 --cache-reuse 1024 -v --flash-attn on -ngl 44 --model-draft /path/to/draft-model -ngld 44 --threads 8 --seed 12345 --log-colors on --another-arg 44 --more-args", # noqa: E501
+            "llama-server --host 0.0.0.0 --port 1337 --log-file /var/tmp/ramalama.log --model /path/to/model --jinja --no-warmup --reasoning-budget 0 --alias library/smollm --ctx-size 512 --temp 11 --cache-reuse 1024 -v -ngl 44 --model-draft /path/to/draft-model -ngld 44 --threads 8 --seed 12345 --log-colors on --another-arg 44 --more-args", # noqa: E501
         ),
         (
             FactoryInput(cli_args=CLIArgs(runtime_args="")),
-            "llama-server --host 0.0.0.0 --port 1337 --log-file /var/tmp/ramalama.log --model /path/to/model --chat-template-file /path/to/chat-template --jinja --no-warmup --reasoning-budget 0 --alias library/smollm --ctx-size 512 --temp 11 --cache-reuse 1024 -v --flash-attn on -ngl 44 --model-draft /path/to/draft-model -ngld 44 --threads 8 --seed 12345 --log-colors on", # noqa: E501
+            "llama-server --host 0.0.0.0 --port 1337 --log-file /var/tmp/ramalama.log --model /path/to/model --chat-template-file /path/to/chat-template --jinja --no-warmup --reasoning-budget 0 --alias library/smollm --ctx-size 512 --temp 11 --cache-reuse 1024 -v -ngl 44 --model-draft /path/to/draft-model -ngld 44 --threads 8 --seed 12345 --log-colors on", # noqa: E501
         ),
         (
             FactoryInput(cli_args=CLIArgs(max_tokens=99, runtime_args="")),
-            "llama-server --host 0.0.0.0 --port 1337 --log-file /var/tmp/ramalama.log --model /path/to/model --chat-template-file /path/to/chat-template --jinja --no-warmup --reasoning-budget 0 --alias library/smollm --ctx-size 512 --temp 11 --cache-reuse 1024 -v --flash-attn on -ngl 44 --model-draft /path/to/draft-model -ngld 44 --threads 8 --seed 12345 --log-colors on -n 99", # noqa: E501
+            "llama-server --host 0.0.0.0 --port 1337 --log-file /var/tmp/ramalama.log --model /path/to/model --chat-template-file /path/to/chat-template --jinja --no-warmup --reasoning-budget 0 --alias library/smollm --ctx-size 512 --temp 11 --cache-reuse 1024 -v -ngl 44 --model-draft /path/to/draft-model -ngld 44 --threads 8 --seed 12345 --log-colors on -n 99", # noqa: E501
         ),
     ],
 )
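The net effect on these expectations is that the rendered llama-server command line no longer carries a hard-coded `--flash-attn on`. A minimal, self-contained check in the same spirit, where `build_server_command` is a hypothetical stand-in for the factory exercised via `FactoryInput` above:

```python
# Hypothetical stand-in for the command factory under test; only meant to show
# the shape of the assertion, not ramalama's real builder.
import shlex


def build_server_command(model_path: str, ngl: int) -> str:
    # Fixed example command; the real factory assembles this from the spec.
    return f"llama-server --host 0.0.0.0 --port 1337 --model {model_path} -ngl {ngl}"


def test_no_default_flash_attn_flag():
    argv = shlex.split(build_server_command("/path/to/model", 44))
    assert "--flash-attn" not in argv
```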