diff --git a/ramalama/command/context.py b/ramalama/command/context.py index e92e9a66..3d0173be 100644 --- a/ramalama/command/context.py +++ b/ramalama/command/context.py @@ -2,13 +2,12 @@ import argparse import os from typing import Optional -from ramalama.common import check_metal, check_nvidia +from ramalama.common import check_metal, check_nvidia, get_accel_env_vars from ramalama.console import should_colorize from ramalama.transports.transport_factory import CLASS_MODEL_TYPES, New class RamalamaArgsContext: - def __init__(self) -> None: self.cache_reuse: Optional[int] = None self.container: Optional[bool] = None @@ -52,7 +51,6 @@ class RamalamaArgsContext: class RamalamaRagGenArgsContext: - def __init__(self) -> None: self.debug: bool | None = None self.format: str | None = None @@ -74,7 +72,6 @@ class RamalamaRagGenArgsContext: class RamalamaRagArgsContext: - def __init__(self) -> None: self.debug: bool | None = None self.port: str | None = None @@ -92,7 +89,6 @@ class RamalamaRagArgsContext: class RamalamaModelContext: - def __init__(self, model: CLASS_MODEL_TYPES, is_container: bool, should_generate: bool, dry_run: bool): self.model = model self.is_container = is_container @@ -128,7 +124,6 @@ class RamalamaModelContext: class RamalamaHostContext: - def __init__( self, is_container: bool, uses_nvidia: bool, uses_metal: bool, should_colorize: bool, rpc_nodes: Optional[str] ): @@ -140,7 +135,6 @@ class RamalamaHostContext: class RamalamaCommandContext: - def __init__( self, args: RamalamaArgsContext | RamalamaRagGenArgsContext | RamalamaRagArgsContext, @@ -169,9 +163,13 @@ class RamalamaCommandContext: model = cli_args.model else: model = None + + skip_gpu_probe = should_generate or bool(get_accel_env_vars()) + uses_nvidia = True if skip_gpu_probe else (check_nvidia() is None) + host = RamalamaHostContext( is_container, - check_nvidia() is None, + uses_nvidia, check_metal(argparse.Namespace(**{"container": is_container})), should_colorize(), os.getenv("RAMALAMA_LLAMACPP_RPC_NODES", None),