
adds support for hosted chat providers

Signed-off-by: Ian Eaves <ian.k.eaves@gmail.com>
Ian Eaves
2026-01-22 01:15:58 -06:00
parent 028ae02d9c
commit db8bb5d9df
34 changed files with 1770 additions and 392 deletions

View File

@@ -118,7 +118,7 @@ endif
.PHONY: lint
lint:
ifneq (,$(wildcard /usr/bin/python3))
/usr/bin/python3 -m compileall -q -x '\.venv' .
${PYTHON} -m compileall -q -x '\.venv' .
endif
! grep -ri $(EXCLUDE_OPTS) "#\!/usr/bin/python3" .
flake8 $(FLAKE8_ARGS) $(PROJECT_DIR) $(PYTHON_SCRIPTS)

View File

@@ -29,6 +29,9 @@ Show this help message and exit
#### **--list**
List the available models at an endpoint
#### **--max-tokens**=*integer*
Maximum number of tokens to generate. Set to 0 for unlimited output (default: 0).
#### **--mcp**=SERVER_URL
MCP (Model Context Protocol) servers to use for enhanced tool calling capabilities.
Can be specified multiple times to connect to multiple MCP servers.
@@ -49,6 +52,10 @@ When enabled, ramalama will periodically condense older messages into a summary,
keeping only recent messages and the summary. This prevents the context from growing
indefinitely during long chat sessions. Set to 0 to disable (default: 4).
#### **--temp**=*float*
Temperature of the response from the AI Model.
Lower numbers are more deterministic, higher numbers are more creative.
#### **--url**=URL
The host to send requests to (default: http://127.0.0.1:8080)

View File

@@ -17,13 +17,17 @@ ramalama\-run - run specified AI Model as a chatbot
| rlcr | rlcr:// | [`ramalama.com`](https://registry.ramalama.com) |
| OCI Container Registries | oci:// | [`opencontainers.org`](https://opencontainers.org)|
|||Examples: [`quay.io`](https://quay.io), [`Docker Hub`](https://docker.io),[`Artifactory`](https://artifactory.com)|
| Hosted API Providers | openai:// | [`api.openai.com`](https://api.openai.com)|
RamaLama defaults to the Ollama registry transport. This default can be overridden in the `ramalama.conf` file or via the RAMALAMA_TRANSPORT
environment variable. For example, `export RAMALAMA_TRANSPORT=huggingface` changes RamaLama to use the huggingface transport.
Modify individual model transports by specifying the `huggingface://`, `oci://`, `ollama://`, `https://`, `http://`, `file://` prefix to the model.
Modify individual model transports by specifying the `huggingface://`, `oci://`, `ollama://`, `https://`, `http://`, `file://`, or hosted API
prefix (`openai://`).
URL support means if a model is on a web site or even on your local system, you can run it directly.
Hosted API transports connect directly to the remote provider and bypass the local container runtime. In this mode, flags that tune local
containers (for example `--image`, GPU settings, or `--network`) do not apply, and the provider's own capabilities and security posture govern
the execution. URL support means if a model is on a web site or even on your local system, you can run it directly.
## OPTIONS

View File

@@ -37,7 +37,7 @@ RamaLama CLI defaults can be modified via ramalama.conf files. Default settings
### Test and run your models more securely
RamaLama defaults to running AI models inside rootless containers using Podman or Docker. These containers isolate the AI models from information on the underlying host. With RamaLama containers, the AI model is mounted as a volume into the container in read-only mode. This results in the process running the model, llama.cpp or vLLM, being isolated from the host. In addition, since `ramalama run` uses the --network=none option, the container cannot reach the network and leak any information out of the system. Finally, containers are run with the --rm option, which means that any content written during the running of the container is wiped out when the application exits.
RamaLama defaults to running AI models inside rootless containers using Podman or Docker. These containers isolate the AI models from information on the underlying host. With RamaLama containers, the AI model is mounted as a volume into the container in read-only mode. This results in the process running the model, llama.cpp or vLLM, being isolated from the host. In addition, since `ramalama run` uses the --network=none option, the container cannot reach the network and leak any information out of the system. Finally, containers are run with the --rm option, which means that any content written during the running of the container is wiped out when the application exits. Hosted API transports such as `openai://` bypass the container runtime entirely and connect directly to the remote provider; those transports inherit the provider's network access and security guarantees instead of RamaLama's container sandbox.
### Here's how RamaLama delivers a robust security footprint:

View File

@@ -205,12 +205,26 @@
# The maximum delay between retry attempts in seconds.
#
#max_retry_delay = 30
[ramalama.provider]
# Provider-specific hosted API configuration. Set per-provider options in the
# nested tables below.
[ramalama.provider.openai]
# Optional provider-specific API key used when contacting OpenAI-hosted
# transports. If unset, RamaLama falls back to the RAMALAMA_API_KEY value
# or environment variables expected by the provider.
#
#api_key = "sk-..."
[ramalama.user]
# Suppress the interactive prompt when running on macOS with a Podman VM
# that doesn't support GPU acceleration (e.g., applehv provider).
# When set to true, RamaLama will automatically proceed without GPU support
# instead of asking for confirmation.
# Can also be set via the `RAMALAMA_USER__NO_MISSING_GPU_PROMPT` environment variable.
#
[ramalama.user]
#no_missing_gpu_prompt = false

View File

@@ -253,6 +253,21 @@ The maximum number of times to retry a failed download
The maximum delay between retry attempts in seconds
## RAMALAMA.PROVIDER TABLE
The `ramalama.provider` table configures hosted API providers that RamaLama can proxy to.
`[[ramalama.provider]]`
**openai**=""
Configuration settings for the OpenAI hosted provider.
`[[ramalama.provider.openai]]`
**api_key**=""
Provider-specific API key used when invoking OpenAI-hosted transports. Overrides `RAMALAMA_API_KEY` when set.
## RAMALAMA.USER TABLE
The ramalama.user table contains user preference settings.

View File

@@ -27,6 +27,76 @@
"node": ">=18.0"
}
},
"node_modules/@ai-sdk/gateway": {
"version": "2.0.24",
"resolved": "https://registry.npmjs.org/@ai-sdk/gateway/-/gateway-2.0.24.tgz",
"integrity": "sha512-mflk80YF8hj8vrF9e1IHhovGKC1ubX+sY88pesSk3pUiXfH5VPO8dgzNnxjwsqsCZrnkHcztxS5cSl4TzSiEuA==",
"license": "Apache-2.0",
"dependencies": {
"@ai-sdk/provider": "2.0.1",
"@ai-sdk/provider-utils": "3.0.20",
"@vercel/oidc": "3.0.5"
},
"engines": {
"node": ">=18"
},
"peerDependencies": {
"zod": "^3.25.76 || ^4.1.8"
}
},
"node_modules/@ai-sdk/provider": {
"version": "2.0.1",
"resolved": "https://registry.npmjs.org/@ai-sdk/provider/-/provider-2.0.1.tgz",
"integrity": "sha512-KCUwswvsC5VsW2PWFqF8eJgSCu5Ysj7m1TxiHTVA6g7k360bk0RNQENT8KTMAYEs+8fWPD3Uu4dEmzGHc+jGng==",
"license": "Apache-2.0",
"dependencies": {
"json-schema": "^0.4.0"
},
"engines": {
"node": ">=18"
}
},
"node_modules/@ai-sdk/provider-utils": {
"version": "3.0.20",
"resolved": "https://registry.npmjs.org/@ai-sdk/provider-utils/-/provider-utils-3.0.20.tgz",
"integrity": "sha512-iXHVe0apM2zUEzauqJwqmpC37A5rihrStAih5Ks+JE32iTe4LZ58y17UGBjpQQTCRw9YxMeo2UFLxLpBluyvLQ==",
"license": "Apache-2.0",
"dependencies": {
"@ai-sdk/provider": "2.0.1",
"@standard-schema/spec": "^1.0.0",
"eventsource-parser": "^3.0.6"
},
"engines": {
"node": ">=18"
},
"peerDependencies": {
"zod": "^3.25.76 || ^4.1.8"
}
},
"node_modules/@ai-sdk/react": {
"version": "2.0.119",
"resolved": "https://registry.npmjs.org/@ai-sdk/react/-/react-2.0.119.tgz",
"integrity": "sha512-kl4CDAnKJ1z+Fc9cjwMQXLRqH5/gHhg8Jn9qW7sZ0LgL8VpiDmW+x+s8e588nE3eC88aL1OxOVyOE6lFYfWprw==",
"license": "Apache-2.0",
"dependencies": {
"@ai-sdk/provider-utils": "3.0.20",
"ai": "5.0.117",
"swr": "^2.2.5",
"throttleit": "2.1.0"
},
"engines": {
"node": ">=18"
},
"peerDependencies": {
"react": "^18 || ~19.0.1 || ~19.1.2 || ^19.2.1",
"zod": "^3.25.76 || ^4.1.8"
},
"peerDependenciesMeta": {
"zod": {
"optional": true
}
}
},
"node_modules/@algolia/abtesting": {
"version": "1.13.0",
"resolved": "https://registry.npmjs.org/@algolia/abtesting/-/abtesting-1.13.0.tgz",
@@ -2832,9 +2902,9 @@
}
},
"node_modules/@csstools/postcss-normalize-display-values": {
"version": "4.0.1",
"resolved": "https://registry.npmjs.org/@csstools/postcss-normalize-display-values/-/postcss-normalize-display-values-4.0.1.tgz",
"integrity": "sha512-TQUGBuRvxdc7TgNSTevYqrL8oItxiwPDixk20qCB5me/W8uF7BPbhRrAvFuhEoywQp/woRsUZ6SJ+sU5idZAIA==",
"version": "4.0.0",
"resolved": "https://registry.npmjs.org/@csstools/postcss-normalize-display-values/-/postcss-normalize-display-values-4.0.0.tgz",
"integrity": "sha512-HlEoG0IDRoHXzXnkV4in47dzsxdsjdz6+j7MLjaACABX2NfvjFS6XVAnpaDyGesz9gK2SC7MbNwdCHusObKJ9Q==",
"funding": [
{
"type": "github",
@@ -5274,9 +5344,9 @@
}
},
"node_modules/@types/express-serve-static-core": {
"version": "4.19.8",
"resolved": "https://registry.npmjs.org/@types/express-serve-static-core/-/express-serve-static-core-4.19.8.tgz",
"integrity": "sha512-02S5fmqeoKzVZCHPZid4b8JH2eM5HzQLZWN2FohQEy/0eXTq8VXZfSN6Pcr3F6N9R/vNrj7cpgbhjie6m/1tCA==",
"version": "4.19.7",
"resolved": "https://registry.npmjs.org/@types/express-serve-static-core/-/express-serve-static-core-4.19.7.tgz",
"integrity": "sha512-FvPtiIf1LfhzsaIXhv/PHan/2FeQBbtBDtfX2QfvPxdUelMDEckK08SM6nqo1MIZY3RUlfA+HV8+hFUSio78qg==",
"license": "MIT",
"dependencies": {
"@types/node": "*",
@@ -5391,9 +5461,15 @@
"license": "MIT"
},
"node_modules/@types/node": {
<<<<<<< HEAD
"version": "25.0.10",
"resolved": "https://registry.npmjs.org/@types/node/-/node-25.0.10.tgz",
"integrity": "sha512-zWW5KPngR/yvakJgGOmZ5vTBemDoSqF3AcV/LrO5u5wTWyEAVVh+IT39G4gtyAkh3CtTZs8aX/yRM82OfzHJRg==",
=======
"version": "25.0.3",
"resolved": "https://registry.npmjs.org/@types/node/-/node-25.0.3.tgz",
"integrity": "sha512-W609buLVRVmeW693xKfzHeIV6nJGGz98uCPfeXI1ELMLXVeKYZ9m15fAMSaUPBHYLGFsVRcMmSCksQOrZV9BYA==",
>>>>>>> ca824d98 (adds support for hosted chat providers)
"license": "MIT",
"dependencies": {
"undici-types": "~7.16.0"
@@ -5418,9 +5494,15 @@
"license": "MIT"
},
"node_modules/@types/react": {
<<<<<<< HEAD
"version": "19.2.9",
"resolved": "https://registry.npmjs.org/@types/react/-/react-19.2.9.tgz",
"integrity": "sha512-Lpo8kgb/igvMIPeNV2rsYKTgaORYdO1XGVZ4Qz3akwOj0ySGYMPlQWa8BaLn0G63D1aSaAQ5ldR06wCpChQCjA==",
=======
"version": "19.2.7",
"resolved": "https://registry.npmjs.org/@types/react/-/react-19.2.7.tgz",
"integrity": "sha512-MWtvHrGZLFttgeEj28VXHxpmwYbor/ATPYbBfSFZEIRK0ecCFLl2Qo55z52Hss+UV9CRN7trSeq1zbgx7YDWWg==",
>>>>>>> ca824d98 (adds support for hosted chat providers)
"license": "MIT",
"dependencies": {
"csstype": "^3.2.2"
@@ -5825,6 +5907,27 @@
"node": ">=8"
}
},
<<<<<<< HEAD
=======
"node_modules/ai": {
"version": "5.0.117",
"resolved": "https://registry.npmjs.org/ai/-/ai-5.0.117.tgz",
"integrity": "sha512-uE6HNkdSwxbeHGKP/YbvapwD8fMOpj87wyfT9Z00pbzOh2fpnw5acak/4kzU00SX2vtI9K0uuy+9Tf9ytw5RwA==",
"license": "Apache-2.0",
"dependencies": {
"@ai-sdk/gateway": "2.0.24",
"@ai-sdk/provider": "2.0.1",
"@ai-sdk/provider-utils": "3.0.20",
"@opentelemetry/api": "1.9.0"
},
"engines": {
"node": ">=18"
},
"peerDependencies": {
"zod": "^3.25.76 || ^4.1.8"
}
},
>>>>>>> ca824d98 (adds support for hosted chat providers)
"node_modules/ajv": {
"version": "8.17.1",
"resolved": "https://registry.npmjs.org/ajv/-/ajv-8.17.1.tgz",
@@ -6189,9 +6292,15 @@
"license": "MIT"
},
"node_modules/baseline-browser-mapping": {
<<<<<<< HEAD
"version": "2.9.17",
"resolved": "https://registry.npmjs.org/baseline-browser-mapping/-/baseline-browser-mapping-2.9.17.tgz",
"integrity": "sha512-agD0MgJFUP/4nvjqzIB29zRPUuCF7Ge6mEv9s8dHrtYD7QWXRcx75rOADE/d5ah1NI+0vkDl0yorDd5U852IQQ==",
=======
"version": "2.9.11",
"resolved": "https://registry.npmjs.org/baseline-browser-mapping/-/baseline-browser-mapping-2.9.11.tgz",
"integrity": "sha512-Sg0xJUNDU1sJNGdfGWhVHX0kkZ+HWcvmVymJbj6NSgZZmW/8S9Y2HQ5euytnIgakgxN6papOAWiwDo1ctFDcoQ==",
>>>>>>> ca824d98 (adds support for hosted chat providers)
"license": "Apache-2.0",
"bin": {
"baseline-browser-mapping": "dist/cli.js"
@@ -6522,9 +6631,15 @@
}
},
"node_modules/caniuse-lite": {
<<<<<<< HEAD
"version": "1.0.30001765",
"resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001765.tgz",
"integrity": "sha512-LWcNtSyZrakjECqmpP4qdg0MMGdN368D7X8XvvAqOcqMv0RxnlqVKZl2V6/mBR68oYMxOZPLw/gO7DuisMHUvQ==",
=======
"version": "1.0.30001762",
"resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001762.tgz",
"integrity": "sha512-PxZwGNvH7Ak8WX5iXzoK1KPZttBXNPuaOvI2ZYU7NrlM+d9Ov+TUvlLOBNGzVXAntMSMMlJPd+jY6ovrVjSmUw==",
>>>>>>> ca824d98 (adds support for hosted chat providers)
"funding": [
{
"type": "opencollective",
@@ -14784,9 +14899,9 @@
}
},
"node_modules/postcss-preset-env": {
"version": "10.6.1",
"resolved": "https://registry.npmjs.org/postcss-preset-env/-/postcss-preset-env-10.6.1.tgz",
"integrity": "sha512-yrk74d9EvY+W7+lO9Aj1QmjWY9q5NsKjK2V9drkOPZB/X6KZ0B3igKsHUYakb7oYVhnioWypQX3xGuePf89f3g==",
"version": "10.6.0",
"resolved": "https://registry.npmjs.org/postcss-preset-env/-/postcss-preset-env-10.6.0.tgz",
"integrity": "sha512-+LzpUSLCGHUdlZ1YZP7lp7w1MjxInJRSG0uaLyk/V/BM17iU2B7xTO7I8x3uk0WQAcLLh/ffqKzOzfaBvG7Fdw==",
"funding": [
{
"type": "github",
@@ -14824,7 +14939,7 @@
"@csstools/postcss-media-minmax": "^2.0.9",
"@csstools/postcss-media-queries-aspect-ratio-number-values": "^3.0.5",
"@csstools/postcss-nested-calc": "^4.0.0",
"@csstools/postcss-normalize-display-values": "^4.0.1",
"@csstools/postcss-normalize-display-values": "^4.0.0",
"@csstools/postcss-oklab-function": "^4.0.12",
"@csstools/postcss-position-area-property": "^1.0.0",
"@csstools/postcss-progressive-custom-properties": "^4.2.1",
@@ -16132,13 +16247,10 @@
"license": "MIT"
},
"node_modules/sax": {
"version": "1.4.4",
"resolved": "https://registry.npmjs.org/sax/-/sax-1.4.4.tgz",
"integrity": "sha512-1n3r/tGXO6b6VXMdFT54SHzT9ytu9yr7TaELowdYpMqY/Ao7EnlQGmAQ1+RatX7Tkkdm6hONI2owqNx2aZj5Sw==",
"license": "BlueOak-1.0.0",
"engines": {
"node": ">=11.0.0"
}
"version": "1.4.3",
"resolved": "https://registry.npmjs.org/sax/-/sax-1.4.3.tgz",
"integrity": "sha512-yqYn1JhPczigF94DMS+shiDMjDowYO6y9+wB/4WgO0Y19jWYk0lQ4tuG5KI7kj4FTp1wxPj5IFfcrz/s1c3jjQ==",
"license": "BlueOak-1.0.0"
},
"node_modules/scheduler": {
"version": "0.27.0",
@@ -18042,9 +18154,9 @@
}
},
"node_modules/webpack-dev-server/node_modules/ws": {
"version": "8.19.0",
"resolved": "https://registry.npmjs.org/ws/-/ws-8.19.0.tgz",
"integrity": "sha512-blAT2mjOEIi0ZzruJfIhb3nps74PRWTCz1IjglWEEpQl5XS/UNama6u2/rjFkDDouqr4L67ry+1aGIALViWjDg==",
"version": "8.18.3",
"resolved": "https://registry.npmjs.org/ws/-/ws-8.18.3.tgz",
"integrity": "sha512-PEIGCY5tSlUt50cqyMXfCzX+oOPqN0vuGqWzbcJ2xvnkzkq46oOpz7dQaTDBdfICb4N14+GARUDw2XV2N4tvzg==",
"license": "MIT",
"engines": {
"node": ">=10.0.0"

View File

@@ -2,9 +2,10 @@
import sys
from ramalama import cli
from ramalama.cli import HelpException, init_cli, print_version
from ramalama.common import perror
assert sys.version_info >= (3, 10), "Python 3.10 or greater is required."
__all__ = ["perror", "init_cli", "print_version", "HelpException"]
__all__ = ["cli", "perror", "init_cli", "print_version", "HelpException"]

View File

@@ -82,6 +82,8 @@ class ChatSubArgsType(Protocol):
rag: str | None
api_key: str | None
ARGS: List[str] | None
max_tokens: int | None
temp: float | None
ChatSubArgs = protocol_to_dataclass(ChatSubArgsType)

View File

@@ -5,7 +5,6 @@ import cmd
import copy
import itertools
import json
import os
import signal
import subprocess
import sys
@@ -13,13 +12,24 @@ import threading
import time
import urllib.error
import urllib.request
from collections.abc import Sequence
from dataclasses import dataclass
from datetime import timedelta
from ramalama.arg_types import ChatArgsType
from ramalama.chat_providers import ChatProvider, ChatRequestOptions
from ramalama.chat_providers.openai import OpenAICompletionsChatProvider
from ramalama.chat_utils import (
AssistantMessage,
ChatMessageType,
SystemMessage,
ToolMessage,
UserMessage,
stream_response,
)
from ramalama.common import perror
from ramalama.config import CONFIG
from ramalama.console import EMOJI, should_colorize
from ramalama.console import should_colorize
from ramalama.engine import stop_container
from ramalama.file_loaders.file_manager import OpanAIChatAPIMessageBuilder
from ramalama.logger import logger
@@ -31,69 +41,6 @@ from ramalama.proxy_support import setup_proxy_support
setup_proxy_support()
def res(response, color):
color_default = ""
color_yellow = ""
if (color == "auto" and should_colorize()) or color == "always":
color_default = "\033[0m"
color_yellow = "\033[33m"
print("\r", end="")
assistant_response = ""
for line in response:
line = line.decode("utf-8").strip()
if line.startswith("data: {"):
choice = ""
json_line = json.loads(line[len("data: ") :])
if "choices" in json_line and json_line["choices"]:
choice = json_line["choices"][0]["delta"]
if "content" in choice:
choice = choice["content"]
else:
continue
if choice:
print(f"{color_yellow}{choice}{color_default}", end="", flush=True)
assistant_response += choice
print("")
return assistant_response
def default_prefix():
if not EMOJI:
return "> "
if CONFIG.prefix:
return CONFIG.prefix
engine = CONFIG.engine
if engine:
if os.path.basename(engine) == "podman":
return "🦭 > "
if os.path.basename(engine) == "docker":
return "🐋 > "
return "🦙 > "
def add_api_key(args, headers=None):
# static analyzers suggest for dict, this is a safer way of setting
# a default value, rather than using the parameter directly
headers = headers or {}
if getattr(args, "api_key", None):
api_key_min = 20
if len(args.api_key) < api_key_min:
perror("Warning: Provided API key is invalid.")
headers["Authorization"] = f"Bearer {args.api_key}"
return headers
@dataclass
class ChatOperationalArgs:
initial_connection: bool = False
@@ -102,18 +49,70 @@ class ChatOperationalArgs:
monitor: "ServerMonitor | None" = None
class Spinner:
def __init__(self, wait_time: float = 0.1):
self._stop_event: threading.Event = threading.Event()
self._thread: threading.Thread | None = None
self.wait_time = wait_time
def __enter__(self):
self.start()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.stop()
return False
def start(self) -> "Spinner":
if not sys.stdout.isatty():
return self
if self._thread is not None:
self.stop()
self._thread = threading.Thread(target=self._spinner_loop, daemon=True)
self._thread.start()
return self
def stop(self):
if self._thread is None:
return
self._stop_event.set()
self._thread.join(timeout=0.2)
perror("\r", end="", flush=True)
self._thread = None
self._stop_event = threading.Event()
def _spinner_loop(self):
frames = ['⠋', '⠙', '⠹', '⠸', '⠼', '⠴', '⠦', '⠧', '⠇', '⠏']
for frame in itertools.cycle(frames):
if self._stop_event.is_set():
break
perror(f"\r{frame}", end="", flush=True)
self._stop_event.wait(self.wait_time)
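# Illustrative usage (not part of this change): besides the explicit
# start()/stop() calls used in _req() below, Spinner also works as a context
# manager and silently no-ops when stdout is not a TTY. A hypothetical sketch:
#
#     with Spinner(wait_time=0.1):
#         response = urllib.request.urlopen(request)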
class RamaLamaShell(cmd.Cmd):
def __init__(self, args: ChatArgsType, operational_args: ChatOperationalArgs | None = None):
def __init__(
self,
args: ChatArgsType,
operational_args: ChatOperationalArgs | None = None,
provider: ChatProvider | None = None,
):
if operational_args is None:
operational_args = ChatOperationalArgs()
super().__init__()
self.conversation_history: list[dict] = []
self.conversation_history: list[ChatMessageType] = []
self.args = args
self.operational_args = operational_args
self.request_in_process = False
self.prompt = args.prefix
self.url = f"{args.url}/chat/completions"
self.provider = provider or OpenAICompletionsChatProvider(args.url, getattr(args, "api_key", None))
self.url = self.provider.build_url()
self.prep_rag_message()
self.mcp_agent: LLMAgent | None = None
self.initialize_mcp()
@@ -131,7 +130,7 @@ class RamaLamaShell(cmd.Cmd):
def _summarize_conversation(self):
"""Summarize the conversation history to prevent context growth."""
if len(self.conversation_history) < 4:
if len(self.conversation_history) < 10:
# Need at least a few messages to summarize
return
@@ -145,16 +144,14 @@ class RamaLamaShell(cmd.Cmd):
return
# Create a summarization prompt
conversation_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages_to_summarize])
summary_prompt = {
"role": "user",
"content": (
f"Please provide a concise summary of the following conversation, "
conversation_text = "\n".join([self._format_message_for_summary(msg) for msg in messages_to_summarize])
summary_prompt = UserMessage(
text=(
"Please provide a concise summary of the following conversation, "
f"preserving key information and context:\n\n{conversation_text}\n\n"
f"Provide only the summary, without any preamble."
),
}
"Provide only the summary, without any preamble."
)
)
# Make API call to get summary
# Provide user feedback during summarization
@@ -167,12 +164,12 @@ class RamaLamaShell(cmd.Cmd):
summary = result['choices'][0]['message']['content']
# Rebuild conversation history with summary
new_history = []
new_history: list[ChatMessageType] = []
if first_msg:
new_history.append(first_msg)
# Add summary as a system message
new_history.append({"role": "system", "content": f"Previous conversation summary: {summary}"})
new_history.append(SystemMessage(text=f"Previous conversation summary: {summary}"))
# Add recent messages
new_history.extend(recent_msgs)
@@ -196,35 +193,42 @@ class RamaLamaShell(cmd.Cmd):
self._summarize_conversation()
self.message_count = 0 # Reset counter after summarization
def _make_api_request(self, messages, stream=True):
"""Make an API request with the given messages.
def _history_snapshot(self) -> list[dict[str, str]]:
return [
{"role": msg.role, "content": self._format_message_for_summary(msg)} for msg in self.conversation_history
]
Args:
messages: List of message dicts to send
stream: Whether to stream the response
def _format_message_for_summary(self, msg: ChatMessageType) -> str:
content = msg.text or ""
if isinstance(msg, AssistantMessage):
if msg.tool_calls:
content += f"\n[tool_calls: {', '.join(call.name for call in msg.tool_calls)}]"
Returns:
urllib.request.Request object
"""
data = {
"stream": stream,
"messages": messages,
}
if getattr(self.args, "model", None):
data["model"] = self.args.model
if getattr(self.args, "temp", None):
data["temperature"] = float(self.args.temp)
if stream and getattr(self.args, "max_tokens", None):
data["max_completion_tokens"] = self.args.max_tokens
if isinstance(msg, ToolMessage):
content = f"\n[tool_response: {msg.text}]"
headers = add_api_key(self.args)
headers["Content-Type"] = "application/json"
return f"{msg.role}: {content}".strip()
return urllib.request.Request(
self.url,
data=json.dumps(data).encode('utf-8'),
headers=headers,
method="POST",
def _make_api_request(self, messages: Sequence[ChatMessageType], stream: bool = True):
"""Create a provider request for arbitrary message lists."""
max_tokens = self.args.max_tokens if stream and getattr(self.args, "max_tokens", None) else None
options = self._build_request_options(stream=stream, max_tokens=max_tokens)
return self.provider.create_request(messages, options)
def _resolve_model_name(self) -> str | None:
if getattr(self.args, "runtime", None) == "mlx":
return None
return getattr(self.args, "model", None)
def _build_request_options(self, *, stream: bool, max_tokens: int | None) -> ChatRequestOptions:
temperature = getattr(self.args, "temp", None)
if max_tokens is not None and max_tokens <= 0:
max_tokens = None
return ChatRequestOptions(
model=self._resolve_model_name(),
temperature=temperature,
max_tokens=max_tokens,
stream=stream,
)
def initialize_mcp(self):
@@ -281,7 +285,7 @@ class RamaLamaShell(cmd.Cmd):
"""Determine if the request should be handled by MCP tools."""
if not self.mcp_agent:
return False
return self.mcp_agent.should_use_tools(content, self.conversation_history)
return self.mcp_agent.should_use_tools(content, self._history_snapshot())
def _handle_mcp_request(self, content: str) -> str:
"""Handle a request using MCP tools (multi-tool capable, automatic)."""
@@ -325,8 +329,8 @@ class RamaLamaShell(cmd.Cmd):
print(f"\n {r['tool']} -> {r['output']}")
# Save to history
self.conversation_history.append({"role": "user", "content": f"/tool {question}"})
self.conversation_history.append({"role": "assistant", "content": str(responses)})
self.conversation_history.append(UserMessage(text=f"/tool {question}"))
self.conversation_history.append(AssistantMessage(text=str(responses)))
def _select_tools(self):
"""Interactive multi-tool selection without prompting for arguments."""
@@ -395,41 +399,26 @@ class RamaLamaShell(cmd.Cmd):
# If streaming, _handle_mcp_request already printed output
if isinstance(response, str) and response.strip():
print(response)
self.conversation_history.append({"role": "user", "content": content})
self.conversation_history.append({"role": "assistant", "content": response})
self.conversation_history.append(UserMessage(text=content))
self.conversation_history.append(AssistantMessage(text=response))
self._check_and_summarize()
return False
self.conversation_history.append({"role": "user", "content": content})
self.conversation_history.append(UserMessage(text=content))
self.request_in_process = True
response = self._req()
if response:
self.conversation_history.append({"role": "assistant", "content": response})
self.conversation_history.append(AssistantMessage(text=response))
self.request_in_process = False
self._check_and_summarize()
def _make_request_data(self):
data = {
"stream": True,
"messages": self.conversation_history,
}
if getattr(self.args, "temp", None):
data["temperature"] = float(self.args.temp)
if getattr(self.args, "max_tokens", None):
data["max_completion_tokens"] = self.args.max_tokens
# For MLX runtime, omit explicit model to allow server default ("default_model")
if getattr(self.args, "runtime", None) != "mlx" and self.args.model is not None:
data["model"] = self.args.model
json_data = json.dumps(data).encode("utf-8")
headers = {
"Content-Type": "application/json",
}
headers = add_api_key(self.args, headers)
logger.debug("Request: URL=%s, Data=%s, Headers=%s", self.url, json_data, headers)
request = urllib.request.Request(self.url, data=json_data, headers=headers, method="POST")
options = self._build_request_options(
stream=True,
max_tokens=getattr(self.args, "max_tokens", None),
)
request = self.provider.create_request(self.conversation_history, options)
logger.debug("Request: URL=%s, Data=%s, Headers=%s", request.full_url, request.data, request.headers)
return request
def _req(self):
@@ -442,28 +431,46 @@ class RamaLamaShell(cmd.Cmd):
# Adjust timeout based on whether we're in initial connection phase
max_timeout = 30 if getattr(self.args, "initial_connection", False) else 16
for c in itertools.cycle(['⠋', '⠙', '⠹', '⠸', '⠼', '⠴', '⠦', '⠧', '⠇', '⠏']):
last_error: Exception | None = None
spinner = Spinner().start()
while True:
try:
response = urllib.request.urlopen(request)
spinner.stop()
break
except Exception:
if sys.stdout.isatty():
perror(f"\r{c}", end="", flush=True)
except urllib.error.HTTPError as http_err:
error_body = http_err.read().decode("utf-8", "ignore").strip()
message = f"HTTP {http_err.code}"
if error_body:
message = f"{message}: {error_body}"
perror(f"\r{message}")
if total_time_slept > max_timeout:
break
self.kills()
spinner.stop()
return None
except Exception as exc:
last_error = exc
total_time_slept += i
time.sleep(i)
if total_time_slept > max_timeout:
break
i = min(i * 2, 0.1)
total_time_slept += i
time.sleep(i)
i = min(i * 2, 0.1)
spinner.stop()
if response:
return res(response, self.args.color)
return stream_response(response, self.args.color, self.provider)
# Only show error and kill if not in initial connection phase
if not getattr(self.args, "initial_connection", False):
perror(f"\rError: could not connect to: {self.url}")
error_suffix = ""
if last_error:
error_suffix = f" ({last_error})"
perror(f"\rError: could not connect to: {self.url}{error_suffix}")
self.kills()
else:
logger.debug(f"Could not connect to: {self.url}")
@@ -721,12 +728,20 @@ def _report_server_exit(monitor):
perror("Check server logs for more details about why the service exited.")
def chat(args: ChatArgsType, operational_args: ChatOperationalArgs | None = None):
def chat(
args: ChatArgsType,
operational_args: ChatOperationalArgs | None = None,
provider: ChatProvider | None = None,
):
if args.dryrun:
assert args.ARGS is not None
prompt = " ".join(args.ARGS)
print(f"\nramalama chat --color {args.color} --prefix \"{args.prefix}\" --url {args.url} {prompt}")
return
if provider is None:
provider = OpenAICompletionsChatProvider(args.url, getattr(args, "api_key", None))
# SIGALRM is Unix-only, skip keepalive timeout handling on Windows
if getattr(args, "keepalive", False) and hasattr(signal, 'SIGALRM'):
signal.signal(signal.SIGALRM, alarm_handler)
@@ -752,14 +767,12 @@ def chat(args: ChatArgsType, operational_args: ChatOperationalArgs | None = None
monitor.start()
list_models = getattr(args, "list", False)
if list_models:
url = f"{args.url}/models"
headers = add_api_key(args)
req = urllib.request.Request(url, headers=headers)
with urllib.request.urlopen(req) as response:
data = json.loads(response.read())
ids = [model["id"] for model in data.get("data", [])]
for id in ids:
print(id)
for model_id in provider.list_models():
print(model_id)
monitor.stop()
if hasattr(signal, 'alarm'):
signal.alarm(0)
return
# Ensure operational_args is initialized
if operational_args is None:
@@ -770,7 +783,7 @@ def chat(args: ChatArgsType, operational_args: ChatOperationalArgs | None = None
successful_exit = True
try:
shell = RamaLamaShell(args, operational_args)
shell = RamaLamaShell(args, operational_args, provider=provider)
if shell.handle_args(monitor):
return

View File

@@ -0,0 +1,9 @@
from ramalama.chat_providers import api_providers, openai
from ramalama.chat_providers.base import (
ChatProvider,
ChatProviderError,
ChatRequestOptions,
ChatStreamEvent,
)
__all__ = ["ChatProvider", "ChatProviderError", "ChatRequestOptions", "ChatStreamEvent", "openai", "api_providers"]

View File

@@ -0,0 +1,33 @@
from collections.abc import Callable
from ramalama.chat_providers.base import ChatProvider
from ramalama.chat_providers.openai import OpenAIResponsesChatProvider
from ramalama.config import CONFIG
PROVIDER_API_KEY_RESOLVERS: dict[str, Callable[[], str | None]] = {
"openai": lambda: CONFIG.provider.openai.api_key,
}
def get_provider_api_key(scheme: str) -> str | None:
"""Return a configured API key for the given provider scheme, if any."""
if resolver := PROVIDER_API_KEY_RESOLVERS.get(scheme):
return resolver()
return CONFIG.api_key
DEFAULT_PROVIDERS = {
"openai": lambda: OpenAIResponsesChatProvider(
base_url="https://api.openai.com/v1", api_key=get_provider_api_key("openai")
)
}
def get_chat_provider(scheme: str) -> ChatProvider:
if (resolver := DEFAULT_PROVIDERS.get(scheme, None)) is None:
raise ValueError(f"No support chat providers for {scheme}")
return resolver()
__all__ = ["get_chat_provider", "get_provider_api_key"]

View File

@@ -0,0 +1,201 @@
import json
from abc import ABC, abstractmethod
from collections.abc import Iterable, Mapping, Sequence
from dataclasses import dataclass
from typing import Any
from urllib import error as urllib_error
from urllib import request as urllib_request
from ramalama.chat_utils import ChatMessageType
from ramalama.config import CONFIG
@dataclass(slots=True)
class ChatRequestOptions:
"""Normalized knobs for building a chat completion request."""
model: str | None = None
temperature: float | None = None
max_tokens: int | None = None
stream: bool = True
extra: dict[str, Any] | None = None
def to_dict(self) -> dict[str, Any]:
keys = ["model", "temperature", "max_tokens", "stream"]
result = {k: v for k in keys if (v := getattr(self, k)) is not None}
result |= {} if self.extra is None else dict(self.extra)
return result
@dataclass(slots=True)
class ChatStreamEvent:
"""A provider-agnostic representation of a streamed delta."""
text: str | None = None
raw: dict[str, Any] | None = None
done: bool = False
class ChatProviderError(Exception):
"""Raised when a provider request fails or returns an invalid payload."""
def __init__(self, message: str, *, status_code: int | None = None, payload: Any | None = None):
super().__init__(message)
self.status_code = status_code
self.payload = payload
class ChatProvider(ABC):
"""Abstract base class for hosted chat providers."""
provider: str = "base"
default_path: str
def __init__(
self,
base_url: str,
api_key: str | None = None,
default_headers: Mapping[str, str] | None = None,
) -> None:
if api_key is None:
api_key = CONFIG.api_key
self.base_url = base_url.rstrip("/")
self.api_key = api_key
self._default_headers: dict[str, str] = dict(default_headers or {})
def build_url(self, path: str | None = None) -> str:
rel = path or self.default_path
if not rel.startswith("/"):
rel = f"/{rel}"
return f"{self.base_url}{rel}"
def prepare_headers(
self,
*,
include_auth: bool = True,
extra: dict[str, str] | None = None,
options: ChatRequestOptions | None = None,
) -> dict[str, str]:
headers: dict[str, str] = {
"Content-Type": "application/json",
**self._default_headers,
**self.provider_headers(options),
}
if include_auth:
headers.update(self.auth_headers())
if extra:
headers.update(extra)
return headers
def auth_headers(self) -> dict[str, str]:
return {"Authorization": f"Bearer {self.api_key}"} if self.api_key else {}
def serialize_payload(self, payload: Mapping[str, Any]) -> bytes:
return json.dumps(payload).encode("utf-8")
def create_request(
self, messages: Sequence[ChatMessageType], options: ChatRequestOptions
) -> urllib_request.Request:
payload = self.build_payload(messages, options)
headers = self.prepare_headers(options=options, extra=self.additional_request_headers(options))
body = self.serialize_payload(payload)
return urllib_request.Request(
self.build_url(self.resolve_request_path(options)),
data=body,
headers=headers,
method="POST",
)
# ------------------------------------------------------------------
# Provider customization points
# ------------------------------------------------------------------
def provider_headers(self, options: ChatRequestOptions | None = None) -> dict[str, str]:
return {}
def additional_request_headers(self, options: ChatRequestOptions | None = None) -> dict[str, str]:
return {}
def resolve_request_path(self, options: ChatRequestOptions | None = None) -> str:
return self.default_path
@abstractmethod
def build_payload(self, messages: Sequence[ChatMessageType], options: ChatRequestOptions) -> Mapping[str, Any]:
"""Return the provider-specific payload."""
@abstractmethod
def parse_stream_chunk(self, chunk: bytes) -> Iterable[ChatStreamEvent]:
"""Yield zero or more events parsed from a streamed response chunk."""
# ------------------------------------------------------------------
# Error handling
# ------------------------------------------------------------------
def raise_for_status(self, status_code: int, payload: Any | None = None) -> None:
if status_code >= 400:
if isinstance(payload, dict) and "error" in payload:
err = payload["error"]
message = str(err.get("message") or err.get("type") or err) if isinstance(err, dict) else str(err)
else:
message = "chat request failed"
raise ChatProviderError(message, status_code=status_code, payload=payload)
# ------------------------------------------------------------------
# Non-streamed helpers
# ------------------------------------------------------------------
def parse_response_body(self, body: bytes) -> Any:
if not body:
return None
return json.loads(body.decode("utf-8"))
def list_models(self) -> list[str]:
"""Return available model identifiers exposed by the provider."""
request = urllib_request.Request(
self.build_url("/models"),
headers=self.prepare_headers(include_auth=True),
method="GET",
)
try:
with urllib_request.urlopen(request) as response:
payload = self.parse_response_body(response.read())
except urllib_error.HTTPError as exc:
if exc.code in (401, 403):
message = (
f"Could not authenticate with {self.provider}. "
"The provided API key was either missing or invalid.\n"
"Set RAMALAMA_API_KEY or ramalama.provider.<provider_name>.api_key."
)
try:
payload = self.parse_response_body(exc.read()) or {}
except Exception:
payload = {}
if details := payload.get("error", {}).get("message", None):
message = f"{message}\n\n{details}"
raise ChatProviderError(message, status_code=exc.code) from exc
raise
if not isinstance(payload, Mapping):
raise ChatProviderError("Invalid model list payload", payload=payload)
data = payload.get("data")
if not isinstance(data, list):
raise ChatProviderError("Invalid model list payload", payload=payload)
models: list[str] = []
for entry in data:
if isinstance(entry, Mapping) and (model_id := entry.get("id")):
models.append(str(model_id))
return models
__all__ = [
"ChatProvider",
"ChatProviderError",
"ChatRequestOptions",
"ChatStreamEvent",
]
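For illustration only (not part of this change): a hypothetical minimal subclass, showing that `build_payload` and `parse_stream_chunk` are the only abstract hooks a new provider has to implement; URL building, auth headers, and `/models` listing are inherited from the base class.

from collections.abc import Iterable, Mapping, Sequence
from typing import Any

from ramalama.chat_providers.base import ChatProvider, ChatRequestOptions, ChatStreamEvent
from ramalama.chat_utils import ChatMessageType


class PlainTextChatProvider(ChatProvider):
    """Hypothetical provider that speaks a simple OpenAI-like dialect."""

    provider = "plaintext"
    default_path = "/chat/completions"

    def build_payload(self, messages: Sequence[ChatMessageType], options: ChatRequestOptions) -> Mapping[str, Any]:
        # Start from the normalized options and add this provider's message shape.
        return {"messages": [{"role": m.role, "content": m.text or ""} for m in messages], **options.to_dict()}

    def parse_stream_chunk(self, chunk: bytes) -> Iterable[ChatStreamEvent]:
        # Treat every chunk as a plain UTF-8 text delta.
        return [ChatStreamEvent(text=chunk.decode("utf-8"))]


provider = PlainTextChatProvider("http://127.0.0.1:8080/v1")  # placeholder base URL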

View File

@@ -0,0 +1,338 @@
import json
from collections.abc import Iterable, Mapping, Sequence
from functools import singledispatch
from typing import Any, TypedDict
from ramalama.chat_providers.base import ChatProvider, ChatRequestOptions, ChatStreamEvent
from ramalama.chat_utils import (
AssistantMessage,
AttachmentPart,
ChatMessageType,
SystemMessage,
ToolMessage,
UserMessage,
serialize_part,
)
class UnsupportedMessageType(Exception):
"""Raised when a provider request fails or returns an invalid payload."""
@singledispatch
def message_to_completions_dict(message: Any) -> dict[str, Any]:
message = (
f"Cannot convert message type `{type(message)}` to a completions dictionary.\n"
"Please create an issue at: https://github.com/containers/ramalama/issues"
)
raise UnsupportedMessageType(message)
@message_to_completions_dict.register
def _(message: SystemMessage) -> dict[str, Any]:
return {**message.metadata, 'content': message.text or "", 'role': message.role}
@message_to_completions_dict.register
def _(message: ToolMessage) -> dict[str, Any]:
response = {
**message.metadata,
'content': message.text or "",
'role': message.role,
}
if message.tool_call_id:
response['tool_call_id'] = message.tool_call_id
return response
@message_to_completions_dict.register
def _(message: UserMessage) -> dict[str, Any]:
if message.attachments:
raise ValueError("Attachments are not supported by this provider.")
return {**message.metadata, 'content': message.text or "", 'role': message.role}
@message_to_completions_dict.register
def _(message: AssistantMessage) -> dict[str, Any]:
if message.attachments:
raise ValueError("Attachments are not supported by this provider.")
tool_calls = [
{
"id": call.id,
"type": "function",
"function": {
"name": call.name,
"arguments": json.dumps(call.arguments, ensure_ascii=False),
},
}
for call in message.tool_calls
]
return {**message.metadata, 'content': message.text or "", 'role': message.role, 'tool_calls': tool_calls}
class CompletionsPayload(TypedDict, total=False):
messages: list[dict[str, Any]]
model: str | None
temperature: float | None
max_tokens: int | None
stream: bool
class OpenAICompletionsChatProvider(ChatProvider):
provider = "openai"
default_path = "/chat/completions"
def __init__(self, base_url: str, api_key: str | None = None):
super().__init__(base_url, api_key)
self._stream_buffer: str = ""
def build_payload(self, messages: Sequence[ChatMessageType], options: ChatRequestOptions) -> CompletionsPayload:
payload: CompletionsPayload = {
"messages": [message_to_completions_dict(m) for m in messages],
"model": options.model,
"temperature": options.temperature,
"max_tokens": options.max_tokens,
"stream": options.stream,
}
return payload
def parse_stream_chunk(self, chunk: bytes) -> Iterable[ChatStreamEvent]:
events: list[ChatStreamEvent] = []
self._stream_buffer += chunk.decode("utf-8")
while "\n\n" in self._stream_buffer:
raw_event, self._stream_buffer = self._stream_buffer.split("\n\n", 1)
raw_event = raw_event.strip()
if not raw_event:
continue
for line in raw_event.splitlines():
if not line.startswith("data:"):
continue
payload = line[len("data:") :].strip()
if not payload:
continue
if payload == "[DONE]":
events.append(ChatStreamEvent(done=True))
continue
try:
parsed = json.loads(payload)
except json.JSONDecodeError:
continue
if delta := self._extract_delta(parsed):
events.append(ChatStreamEvent(text=delta, raw=parsed))
return events
def _extract_delta(self, payload: Mapping[str, object]) -> str | None:
choices = payload.get("choices")
if not isinstance(choices, list) or not choices:
return None
choice = choices[0]
if not isinstance(choice, Mapping):
return None
delta = choice.get("delta")
if not isinstance(delta, Mapping):
return None
content = delta.get("content")
if isinstance(content, str):
return content
if isinstance(content, list):
parts: list[str] = []
for entry in content:
if not isinstance(entry, Mapping):
continue
entry_type = entry.get("type")
text_value = entry.get("text")
if entry_type in {"text", "output_text"} and isinstance(text_value, str):
parts.append(text_value)
if parts:
return "".join(parts)
return None
@singledispatch
def message_to_responses_dict(message: Any) -> dict[str, Any]:
raise ValueError(f"Undefined message type {type(message)}")
def create_responses_content(
text: str | None, attachments: list[AttachmentPart], content_type: str
) -> list[dict[str, Any]] | str:
"""
TODO: Current structure doesn't correctly reflect document ordering
(i.e. the possibility of messages interspersed with content)
"""
content: list[dict[str, Any]] = []
if text:
content.append({"type": content_type, "text": text})
for attachment in attachments:
content.append(serialize_part(attachment))
return content or ""
@message_to_responses_dict.register
def _(message: SystemMessage) -> dict[str, Any]:
return {**message.metadata, 'content': message.text or "", 'role': message.role}
@message_to_responses_dict.register
def _(message: ToolMessage) -> dict[str, Any]:
response = {
**message.metadata,
'content': message.text or "",
'role': message.role,
}
if message.tool_call_id:
response['tool_call_id'] = message.tool_call_id
return response
@message_to_responses_dict.register
def _(message: UserMessage) -> dict[str, Any]:
return {
**message.metadata,
'content': create_responses_content(message.text, message.attachments, "input_text"),
'role': message.role,
}
@message_to_responses_dict.register
def _(message: AssistantMessage) -> dict[str, Any]:
payload: dict[str, Any] = {
**message.metadata,
'content': create_responses_content(message.text, message.attachments, "output_text"),
'role': message.role,
}
tool_calls = [
{
"id": call.id,
"type": "function",
"function": {
"name": call.name,
"arguments": json.dumps(call.arguments, ensure_ascii=False),
},
}
for call in message.tool_calls
]
if tool_calls:
payload['tool_calls'] = tool_calls
return payload
class ResponsesPayload(TypedDict, total=False):
input: list[dict[str, Any]]
model: str
temperature: float | None
max_completion_tokens: int
stream: bool
class OpenAIResponsesChatProvider(ChatProvider):
provider = "openai"
default_path: str = "/responses"
def __init__(self, base_url: str, api_key: str | None = None):
super().__init__(base_url, api_key)
self._stream_buffer: str = ""
def build_payload(self, messages: Sequence[ChatMessageType], options: ChatRequestOptions) -> ResponsesPayload:
if options.model is None:
raise ValueError("Chat options require a model value")
payload: ResponsesPayload = {
"input": [message_to_responses_dict(m) for m in messages],
"model": options.model,
"temperature": options.temperature,
"stream": options.stream,
}
if options.max_tokens is not None and options.max_tokens > 0:
payload["max_completion_tokens"] = options.max_tokens
return payload
def parse_stream_chunk(self, chunk: bytes) -> Iterable[ChatStreamEvent]:
events: list[ChatStreamEvent] = []
self._stream_buffer += chunk.decode("utf-8")
while "\n\n" in self._stream_buffer:
raw_event, self._stream_buffer = self._stream_buffer.split("\n\n", 1)
raw_event = raw_event.strip()
if not raw_event:
continue
event_type = ""
data_lines: list[str] = []
for line in raw_event.splitlines():
if line.startswith("event:"):
event_type = line[len("event:") :].strip()
elif line.startswith("data:"):
data_lines.append(line[len("data:") :].strip())
data = "\n".join(data_lines).strip()
if not data:
continue
if data == "[DONE]":
events.append(ChatStreamEvent(done=True))
continue
try:
payload = json.loads(data)
except json.JSONDecodeError:
continue
if self._is_completion_event(event_type, payload):
events.append(ChatStreamEvent(done=True, raw=payload))
continue
if text := self._extract_responses_delta(event_type, payload):
events.append(ChatStreamEvent(text=text, raw=payload))
return events
@staticmethod
def _is_completion_event(event_type: str, payload: Mapping[str, Any]) -> bool:
hinted_type = event_type or (payload.get("type") if isinstance(payload, Mapping) else "")
return hinted_type == "response.completed"
@staticmethod
def _extract_responses_delta(event_type: str, payload: Mapping[str, Any]) -> str | None:
if not event_type:
event_type = payload.get("type", "") if isinstance(payload, Mapping) else ""
if event_type == "response.output_text.delta":
delta = payload.get("delta")
if isinstance(delta, Mapping):
text = delta.get("text")
if isinstance(text, str):
return text
elif isinstance(delta, str):
return delta
if event_type == "response.output_text.done":
output = payload.get("output")
if isinstance(output, list) and output:
first = output[0]
if isinstance(first, Mapping):
content = first.get("content")
if isinstance(content, list) and content:
entry = content[0]
if isinstance(entry, Mapping):
text = entry.get("text")
if isinstance(text, str):
return text
return None
__all__ = ["OpenAICompletionsChatProvider", "OpenAIResponsesChatProvider"]

ramalama/chat_utils.py Normal file
View File

@@ -0,0 +1,137 @@
import base64
import os
from collections.abc import Iterable
from dataclasses import dataclass, field
from typing import Any, Literal, Protocol
from ramalama.config import CONFIG
from ramalama.console import EMOJI, should_colorize
RoleType = Literal["system", "user", "assistant", "tool"]
@dataclass(slots=True)
class ImageURLPart:
url: str
detail: str | None = None
type: Literal["image_url"] = "image_url"
@dataclass(slots=True)
class ImageBytesPart:
data: bytes
mime_type: str = "application/octet-stream"
type: Literal["image_bytes"] = "image_bytes"
@dataclass(slots=True)
class ToolCall:
id: str
name: str
arguments: dict[str, Any]
AttachmentPart = ImageURLPart | ImageBytesPart
@dataclass(slots=True)
class SystemMessage:
role: Literal["system"] = "system"
text: str = ""
metadata: dict[str, Any] = field(default_factory=dict)
@dataclass(slots=True)
class UserMessage:
role: Literal["user"] = "user"
text: str | None = None
metadata: dict[str, Any] = field(default_factory=dict)
attachments: list[AttachmentPart] = field(default_factory=list)
@dataclass(slots=True)
class AssistantMessage:
role: Literal["assistant"] = "assistant"
text: str | None = None
metadata: dict[str, Any] = field(default_factory=dict)
tool_calls: list[ToolCall] = field(default_factory=list)
attachments: list[AttachmentPart] = field(default_factory=list)
@dataclass(slots=True)
class ToolMessage:
text: str
role: Literal["tool"] = "tool"
metadata: dict[str, Any] = field(default_factory=dict)
tool_call_id: str | None = None
ChatMessageType = SystemMessage | UserMessage | AssistantMessage | ToolMessage
class StreamParser(Protocol):
def parse_stream_chunk(self, chunk: bytes) -> Iterable[Any]: ...
def stream_response(chunks: Iterable[bytes], color: str, provider: StreamParser) -> str:
color_default = ""
color_yellow = ""
if (color == "auto" and should_colorize()) or color == "always":
color_default = "\033[0m"
color_yellow = "\033[33m"
print("\r", end="")
assistant_response = ""
for chunk in chunks:
events = provider.parse_stream_chunk(chunk)
for event in events:
text = getattr(event, "text", None)
if not text:
continue
print(f"{color_yellow}{text}{color_default}", end="", flush=True)
assistant_response += text
print("")
return assistant_response
def default_prefix() -> str:
if not EMOJI:
return "> "
if CONFIG.prefix:
return CONFIG.prefix
if engine := CONFIG.engine:
if os.path.basename(engine) == "podman":
return "🦭 > "
if os.path.basename(engine) == "docker":
return "🐋 > "
return "🦙 > "
def serialize_part(part: AttachmentPart) -> dict[str, Any]:
if isinstance(part, ImageURLPart):
payload: dict[str, Any] = {"url": part.url}
if part.detail:
payload["detail"] = part.detail
return {"type": "image_url", "image_url": payload}
if isinstance(part, ImageBytesPart):
return {
"type": "image_bytes",
"image_bytes": {"data": base64.b64encode(part.data).decode("ascii"), "mime_type": part.mime_type},
}
raise TypeError(f"Unsupported message part: {part!r}")
__all__ = [
"ToolCall",
"ImageURLPart",
"ImageBytesPart",
"default_prefix",
"stream_response",
"serialize_part",
]
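For illustration only (not part of this change): a sketch of the typed message history that these dataclasses replace the old role/content dicts with, plus attachment serialization. The URL is a placeholder.

from ramalama.chat_utils import AssistantMessage, ImageURLPart, UserMessage, serialize_part

history = [
    UserMessage(text="Describe this image.", attachments=[ImageURLPart(url="https://example.com/cat.png")]),
    AssistantMessage(text="A cat sitting on a windowsill."),
]
serialize_part(history[0].attachments[0])
# -> {'type': 'image_url', 'image_url': {'url': 'https://example.com/cat.png'}}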

View File

@@ -24,7 +24,7 @@ except Exception:
import ramalama.chat as chat
from ramalama import engine
from ramalama.arg_types import DefaultArgsType
from ramalama.chat import default_prefix
from ramalama.chat_utils import default_prefix
from ramalama.cli_arg_normalization import normalize_pull_arg
from ramalama.command.factory import assemble_command
from ramalama.common import accel_image, get_accel, perror
@@ -48,6 +48,7 @@ from ramalama.path_utils import file_uri_to_path
from ramalama.rag import INPUT_DIR, Rag, RagTransport, rag_image
from ramalama.shortnames import Shortnames
from ramalama.stack import Stack
from ramalama.transports.api import APITransport
from ramalama.transports.base import (
MODEL_TYPES,
NoGGUFModelFileFound,
@@ -1033,7 +1034,8 @@ If GPU device on host is accessible to via group access, this option leaks the u
)
parser.add_argument(
"--temp",
default=CONFIG.temp,
type=float,
default=float(CONFIG.temp),
help="temperature of the response from the AI model",
completer=suppressCompleter,
)
@@ -1118,6 +1120,19 @@ def chat_parser(subparsers):
parser.add_argument("--url", type=str, default="http://127.0.0.1:8080/v1", help="the url to send requests to")
parser.add_argument("--model", "-m", type=str, completer=local_models, help="model for inferencing")
parser.add_argument("--rag", type=str, help="a file or directory to use as context for the chat")
parser.add_argument(
"--max-tokens",
dest="max_tokens",
type=int,
default=CONFIG.max_tokens,
help="maximum number of tokens to generate (0 = unlimited)",
)
parser.add_argument(
"--temp",
type=float,
default=float(CONFIG.temp),
help="temperature of the response from the AI model",
)
parser.add_argument(
"ARGS", nargs="*", help="overrides the default prompt, and the output is returned without entering the chatbot"
)
@@ -1165,7 +1180,6 @@ def run_cli(args):
try:
# detect available port and update arguments
args.port = compute_serving_port(args)
model = New(args.MODEL, args)
model.ensure_model_exists(args)
except KeyError as e:
@@ -1177,6 +1191,11 @@ def run_cli(args):
except Exception as exc:
raise e from exc
is_api_transport = isinstance(model, APITransport)
if args.rag and is_api_transport:
raise ValueError("ramalama run --rag is not supported for hosted API transports.")
if args.rag:
if not args.container:
raise ValueError("ramalama run --rag cannot be run with the --nocontainer option.")
@@ -1184,7 +1203,9 @@ def run_cli(args):
model = RagTransport(model, assemble_command(args.model_args), args)
model.ensure_model_exists(args)
model.run(args, assemble_command(args))
server_cmd = [] if isinstance(model, APITransport) else assemble_command(args)
model.run(args, server_cmd)
def serve_parser(subparsers):
@@ -1224,6 +1245,9 @@ def serve_cli(args):
except Exception:
raise e
if isinstance(model, APITransport):
raise ValueError("ramalama serve is not supported for hosted API transports.")
if args.rag:
if not args.container:
raise ValueError("ramalama serve --rag cannot be run with the --nocontainer option.")

View File

@@ -1,5 +1,3 @@
"""ramalama common module."""
from __future__ import annotations
import glob

View File

@@ -144,6 +144,16 @@ class UserConfig:
self.no_missing_gpu_prompt = coerce_to_bool(self.no_missing_gpu_prompt)
@dataclass
class OpenaiProviderConfig:
api_key: str | None = None
@dataclass
class ProviderConfig:
openai: OpenaiProviderConfig = field(default_factory=OpenaiProviderConfig)
@dataclass
class RamalamaSettings:
"""These settings are not managed directly by the user"""
@@ -253,6 +263,7 @@ class BaseConfig:
gguf_quantization_mode: GGUF_QUANTIZATION_MODES = DEFAULT_GGUF_QUANTIZATION_MODE
http_client: HTTPClientConfig = field(default_factory=HTTPClientConfig)
log_level: LogLevel | None = None
provider: ProviderConfig = field(default_factory=ProviderConfig)
def __post_init__(self):
self.container = coerce_to_bool(self.container) if self.container is not None else self.engine is not None

View File

@@ -4,6 +4,7 @@ from string import Template
from typing import Type
from warnings import warn
from ramalama.chat_utils import AttachmentPart, ChatMessageType, ImageURLPart, UserMessage
from ramalama.file_loaders.file_types import base, image, txt
@@ -115,17 +116,18 @@ class OpanAIChatAPIMessageBuilder:
def supported_extensions(self) -> set[str]:
return self.text_manager.loaders.keys() | self.image_manager.loaders.keys()
def load(self, file_path: str) -> list[dict]:
def load(self, file_path: str) -> list[ChatMessageType]:
text_files, image_files, unsupported_files = self.partition_files(file_path)
if unsupported_files:
unsupported_files_warning(unsupported_files, list(self.supported_extensions()))
messages: list[dict] = []
messages: list[ChatMessageType] = []
if text_files:
messages.append({"role": "system", "content": self.text_manager.load(text_files)})
messages.append(UserMessage(text=self.text_manager.load(text_files)))
if image_files:
content = [{"type": "image_url", "image_url": {"url": c}} for c in self.image_manager.load(image_files)]
message = {"role": "system", "content": content}
messages.append(message)
attachments: list[AttachmentPart] = []
for data_url in self.image_manager.load(image_files):
attachments.append(ImageURLPart(url=data_url))
messages.append(UserMessage(attachments=attachments))
return messages

View File

@@ -124,7 +124,7 @@ class Stack:
'--ctx-size',
str(self.args.context),
'--temp',
self.args.temp,
str(self.args.temp),
'--jinja',
'--cache-reuse',
'256',

View File

@@ -1,11 +1,21 @@
from .huggingface import Huggingface, HuggingfaceRepository
from .modelscope import ModelScope, ModelScopeRepository
from .oci import OCI
from .ollama import Ollama, OllamaRepository
from .rlcr import RamalamaContainerRegistry
from .url import URL
from ramalama.transports import api, huggingface, modelscope, oci, ollama, rlcr, transport_factory, url
from ramalama.transports.api import APITransport
from ramalama.transports.huggingface import Huggingface, HuggingfaceRepository
from ramalama.transports.modelscope import ModelScope, ModelScopeRepository
from ramalama.transports.oci import OCI
from ramalama.transports.ollama import Ollama, OllamaRepository
from ramalama.transports.rlcr import RamalamaContainerRegistry
from ramalama.transports.url import URL
__all__ = [
"api",
"huggingface",
"oci",
"modelscope",
"ollama",
"rlcr",
"transport_factory",
"url",
"Huggingface",
"HuggingfaceRepository",
"ModelScope",
@@ -15,4 +25,5 @@ __all__ = [
"OllamaRepository",
"RamalamaContainerRegistry",
"URL",
"APITransport",
]

ramalama/transports/api.py Normal file
View File

@@ -0,0 +1,108 @@
from typing import Any
from ramalama.chat import chat
from ramalama.chat_providers.base import ChatProvider, ChatProviderError
from ramalama.common import perror
from ramalama.transports.base import TransportBase
class APITransport(TransportBase):
"""Transport that proxies chat requests to a hosted API provider."""
type: str = "api"
def __init__(self, model: str, provider: ChatProvider):
self.model = model
self.provider = provider
self._model_tag = "latest"
self._model_name = self.model
self.draft_model = None
@property
def model_name(self) -> str:
return self.model
@property
def model_tag(self) -> str:
return self._model_tag
@property
def model_organization(self) -> str:
return self.provider.provider
@property
def model_type(self) -> str:
return self.type
def _get_entry_model_path(self, use_container: bool, should_generate: bool, dry_run: bool) -> str:
raise NotImplementedError(
f"{self.model} is provided over a hosted API preventing direct pulling of the model file."
)
def _get_mmproj_path(self, use_container: bool, should_generate: bool, dry_run: bool):
return None
def _get_chat_template_path(self, use_container: bool, should_generate: bool, dry_run: bool):
return None
def remove(self, args):
raise NotImplementedError("Hosted API transports do not support removing remote models.")
def bench(self, args, cmd: list[str]):
raise NotImplementedError("bench is not supported for hosted API transports.")
def run(self, args, server_cmd: list[str]):
"""Connect directly to the provider instead of launching a local server."""
args.container = False
args.engine = None
args.model = self.model
if getattr(args, "url", None):
self.provider.base_url = args.url
if getattr(args, "api_key", None):
self.provider.api_key = args.api_key
chat(args, provider=self.provider)
def perplexity(self, args, cmd: list[str]):
raise NotImplementedError("perplexity is not supported for hosted API transports.")
def serve(self, args, cmd: list[str]):
raise NotImplementedError("Hosted API transports cannot be served locally.")
def exists(self) -> bool:
return True
def inspect(self, args):
return {
"provider": self.provider.provider,
"model": self.model_name,
"base_url": self.provider.base_url,
}
def ensure_model_exists(self, args):
args.container = False
args.engine = None
if not self.provider.api_key:
raise ValueError(
f'Missing API key for provider "{self.provider.provider}". '
"Set RAMALAMA_API_KEY or ramalama.provider.openai.api_key."
)
try:
models = self.provider.list_models()
except ChatProviderError as exc:
raise ValueError(str(exc)) from exc
except Exception as exc:
raise RuntimeError(f'Failed to list models for provider "{self.provider.provider}"') from exc
if self.model not in models:
available = ", ".join(models) if models else "none"
raise ValueError(
f'Model "{self.model}" not available from provider "{self.provider.provider}". '
f"Available models: {available}"
)
def pull(self, args: Any):
perror(f"{self.model} is provided over a hosted API preventing direct pulling of the model file.")

View File

@@ -4,10 +4,12 @@ from typing import TypeAlias
from urllib.parse import urlparse
from ramalama.arg_types import StoreArgType
from ramalama.chat_providers.api_providers import get_chat_provider
from ramalama.common import rm_until_substring
from ramalama.config import CONFIG
from ramalama.path_utils import file_uri_to_path
from ramalama.transports.base import MODEL_TYPES
from ramalama.transports.api import APITransport
from ramalama.transports.base import MODEL_TYPES, Transport
from ramalama.transports.huggingface import Huggingface
from ramalama.transports.modelscope import ModelScope
from ramalama.transports.oci import OCI
@@ -15,7 +17,7 @@ from ramalama.transports.ollama import Ollama
from ramalama.transports.rlcr import RamalamaContainerRegistry
from ramalama.transports.url import URL
CLASS_MODEL_TYPES: TypeAlias = Huggingface | Ollama | OCI | URL | ModelScope | RamalamaContainerRegistry
CLASS_MODEL_TYPES: TypeAlias = Huggingface | Ollama | OCI | URL | ModelScope | RamalamaContainerRegistry | APITransport
class TransportFactory:
@@ -26,6 +28,7 @@ class TransportFactory:
transport: str = "ollama",
ignore_stderr: bool = False,
):
self.model = model
self.store_path = args.store
self.transport = transport
@@ -39,41 +42,45 @@ class TransportFactory:
self._create = _create
self.pruned_model = self.prune_model_input()
self.draft_model = None
self.draft_model: Transport | None = None
if getattr(args, 'model_draft', None):
model_draft = getattr(args, "model_draft", None)
if model_draft:
dm_args = copy.deepcopy(args)
dm_args.model_draft = None # type: ignore
self.draft_model = TransportFactory(args.model_draft, dm_args, ignore_stderr=True).create() # type: ignore
draft_model = TransportFactory(model_draft, dm_args, ignore_stderr=True).create()
if not isinstance(draft_model, Transport):
raise ValueError("Draft models must be local transports; hosted API transports are not supported.")
self.draft_model = draft_model
def detect_model_model_type(self) -> tuple[type[CLASS_MODEL_TYPES], Callable[[], CLASS_MODEL_TYPES]]:
for prefix in ["huggingface://", "hf://", "hf.co/"]:
if self.model.startswith(prefix):
match self.model:
case model if model.startswith(("huggingface://", "hf://", "hf.co/")):
return Huggingface, self.create_huggingface
for prefix in ["modelscope://", "ms://"]:
if self.model.startswith(prefix):
case model if model.startswith(("modelscope://", "ms://")):
return ModelScope, self.create_modelscope
for prefix in ["ollama://", "ollama.com/library/"]:
if self.model.startswith(prefix):
case model if model.startswith(("ollama://", "ollama.com/library/")):
return Ollama, self.create_ollama
for prefix in ["oci://", "docker://"]:
if self.model.startswith(prefix):
case model if model.startswith(("oci://", "docker://")):
return OCI, self.create_oci
if self.model.startswith("rlcr://"):
return RamalamaContainerRegistry, self.create_rlcr
for prefix in ["http://", "https://", "file:"]:
if self.model.startswith(prefix):
case model if model.startswith("rlcr://"):
return RamalamaContainerRegistry, self.create_rlcr
case model if model.startswith(("http://", "https://", "file:")):
return URL, self.create_url
if self.transport == "huggingface":
return Huggingface, self.create_huggingface
if self.transport == "modelscope":
return ModelScope, self.create_modelscope
if self.transport == "ollama":
return Ollama, self.create_ollama
if self.transport == "rlcr":
return RamalamaContainerRegistry, self.create_rlcr
if self.transport == "oci":
return OCI, self.create_oci
case model if model.startswith(("openai://")):
return APITransport, self.create_api_transport
match self.transport:
case "huggingface":
return Huggingface, self.create_huggingface
case "modelscope":
return ModelScope, self.create_modelscope
case "ollama":
return Ollama, self.create_ollama
case "rlcr":
return RamalamaContainerRegistry, self.create_rlcr
case "oci":
return OCI, self.create_oci
raise KeyError(f'transport "{self.transport}" not supported. Must be oci, huggingface, modelscope, or ollama.')
@@ -155,6 +162,10 @@ class TransportFactory:
model.draft_model = self.draft_model
return model
def create_api_transport(self) -> APITransport:
scheme = self.model.split("://", 1)[0]
return APITransport(self.pruned_model, provider=get_chat_provider(scheme))
def New(name, args, transport: str | None = None) -> CLASS_MODEL_TYPES:
if transport is None:
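
For orientation, create_api_transport above splits the scheme off the model reference and asks the provider registry for a matching client. A sketch of that lookup, with behaviour as exercised by the new provider tests further down:

```python
from ramalama.chat_providers.api_providers import get_chat_provider

provider = get_chat_provider("openai")   # scheme taken from e.g. "openai://gpt-4o-mini"
print(provider.provider)                 # "openai"
print(provider.base_url)                 # "https://api.openai.com/v1"

# Unknown schemes are rejected rather than falling back to a local transport.
try:
    get_chat_provider("anthropic")
except ValueError as exc:
    print(f"unsupported provider scheme: {exc}")
```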

View File

@@ -61,3 +61,4 @@
"smolvlm:500m" = "hf://ggml-org/SmolVLM-500M-Instruct-GGUF"
"tiny" = "hf://TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF"
"tinyllama" = "hf://TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF"
"gpt-5.1" = "openai://gpt-5.1-2025-11-13"

View File

@@ -0,0 +1 @@
# Package marker for provider-specific tests.

View File

@@ -0,0 +1,146 @@
import json
import pytest
from ramalama.chat_providers.base import ChatRequestOptions
from ramalama.chat_providers.openai import OpenAICompletionsChatProvider, OpenAIResponsesChatProvider
from ramalama.chat_utils import AssistantMessage, ImageURLPart, ToolCall, ToolMessage, UserMessage
def build_payload(content):
return {
"choices": [
{
"delta": {
"content": content,
}
}
]
}
def make_options(**overrides):
data = {"model": "test-model", "stream": True}
data.update(overrides)
return ChatRequestOptions(**data)
class OpenAICompletionsProviderTests:
def setup_method(self):
self.provider = OpenAICompletionsChatProvider("http://example.com")
def test_extracts_string_and_structured_deltas(self):
assert self.provider._extract_delta(build_payload("hello")) == "hello"
structured = build_payload(
[
{"type": "text", "text": "hello"},
{"type": "output_text", "text": " world"},
]
)
assert self.provider._extract_delta(structured) == "hello world"
def test_streaming_handles_structured_chunks(self):
chunk = (
b"data: "
+ json.dumps(
build_payload(
[
{"type": "output_text", "text": "Hi"},
{"type": "output_text", "text": " there"},
]
)
).encode("utf-8")
+ b"\n\n"
)
events = list(self.provider.parse_stream_chunk(chunk))
assert len(events) == 1
assert events[0].text == "Hi there"
def test_rejects_attachments(self):
message = UserMessage(attachments=[ImageURLPart(url="http://img")])
with pytest.raises(ValueError):
self.provider.build_payload([message], make_options())
def test_serializes_tool_calls_and_responses(self):
tool_call = ToolCall(id="call-1", name="lookup", arguments={"query": "weather"})
assistant = AssistantMessage(tool_calls=[tool_call])
tool_reply = ToolMessage(text="72F and sunny", tool_call_id="call-1")
payload = self.provider.build_payload([assistant, tool_reply], make_options())
messages = payload["messages"]
assert messages[0]["tool_calls"][0]["function"]["name"] == "lookup"
assert messages[0]["tool_calls"][0]["function"]["arguments"] == '{"query": "weather"}'
assert messages[1]["tool_call_id"] == "call-1"
class OpenAIResponsesProviderTests:
def setup_method(self):
self.provider = OpenAIResponsesChatProvider("http://example.com")
def test_serializes_structured_content(self):
message = UserMessage(
text="hello",
attachments=[ImageURLPart(url="http://img", detail="high")],
)
payload = self.provider.build_payload([message], make_options(max_tokens=128))
serialized = payload["input"][0]["content"]
assert serialized[0] == {"type": "input_text", "text": "hello"}
assert serialized[1]["type"] == "image_url"
assert serialized[1]["image_url"] == {"url": "http://img", "detail": "high"}
assert payload["max_completion_tokens"] == 128
assert "max_tokens" not in payload
def test_streaming_emits_delta_and_completion_events(self):
chunk = (
b"event: response.output_text.delta\n"
b'data: {"type":"response.output_text.delta","delta":{"text":"Hi"}}\n\n'
b"event: response.completed\n"
b'data: {"type":"response.completed"}\n\n'
)
events = list(self.provider.parse_stream_chunk(chunk))
assert events[0].text == "Hi"
assert events[1].done is True
def test_serializes_tool_calls(self):
tool_call = ToolCall(id="call-9", name="lookup", arguments={"city": "NYC"})
assistant = AssistantMessage(tool_calls=[tool_call])
tool_reply = ToolMessage(text="Clear skies", tool_call_id="call-9")
payload = self.provider.build_payload([assistant, tool_reply], make_options())
first_input = payload["input"][0]
assert first_input["tool_calls"][0]["function"]["name"] == "lookup"
assert first_input["tool_calls"][0]["function"]["arguments"] == '{"city": "NYC"}'
assert payload["input"][1]["tool_call_id"] == "call-9"
def test_streaming_emits_done_event_for_done_marker(self):
events = list(self.provider.parse_stream_chunk(b"data: [DONE]\n\n"))
assert len(events) == 1
assert events[0].done is True
def test_streaming_ignores_invalid_json_chunks(self):
events = list(self.provider.parse_stream_chunk(b"data: {invalid-json\n\n"))
assert events == []
def test_streaming_extracts_text_from_done_events(self):
chunk = (
b"event: response.output_text.done\n"
b'data: {"type":"response.output_text.done","output":[{"content":'
b'[{"type":"output_text","text":"All done"}]}]}\n\n'
)
events = list(self.provider.parse_stream_chunk(chunk))
assert len(events) == 1
assert events[0].text == "All done"

View File

@@ -0,0 +1,17 @@
import pytest
from ramalama.chat_providers.api_providers import get_chat_provider
from ramalama.chat_providers.openai import OpenAIResponsesChatProvider
def test_get_chat_provider_returns_openai_provider():
provider = get_chat_provider("openai")
assert isinstance(provider, OpenAIResponsesChatProvider)
assert provider.base_url == "https://api.openai.com/v1"
assert provider.provider == "openai"
def test_get_chat_provider_raises_for_unknown_scheme():
with pytest.raises(ValueError):
get_chat_provider("anthropic")

View File

@@ -0,0 +1,89 @@
from types import SimpleNamespace
import pytest
from ramalama.chat_providers.openai import OpenAIResponsesChatProvider
from ramalama.config import CONFIG
from ramalama.transports import api as api_module
from ramalama.transports.api import APITransport
def make_provider(api_key: str = "provider-default") -> OpenAIResponsesChatProvider:
return OpenAIResponsesChatProvider("https://api.openai.com/v1", api_key=api_key)
def test_api_transport_run(monkeypatch):
provider = make_provider()
transport = APITransport("gpt-4o-mini", provider)
recorded: dict[str, object] = {}
def fake_chat(args, operational_args=None, provider=None):
recorded["args"] = args
recorded["operational_args"] = operational_args
recorded["provider"] = provider
monkeypatch.setattr(api_module, "chat", fake_chat)
args = SimpleNamespace(
container=True, engine="podman", url="http://localhost", model=None, api="none", api_key=None
)
transport.run(args, [])
assert args.container is False
assert args.engine is None
assert args.url == provider.base_url
assert args.model == "gpt-4o-mini"
assert recorded["args"] is args
assert recorded["provider"] is provider
assert provider.base_url == "http://localhost"
assert provider.api_key == "provider-default"
def test_api_transport_ensure_exists_mutates_args(monkeypatch):
provider = make_provider()
transport = APITransport("gpt-4", provider)
args = SimpleNamespace(container=True, engine="podman")
monkeypatch.setattr(provider, "list_models", lambda: ["gpt-4", "other"])
transport.ensure_model_exists(args)
assert args.container is False
assert args.engine is None
def test_api_transport_ensure_exists_requires_api_key(monkeypatch):
monkeypatch.setattr(CONFIG, "api_key", None)
provider = make_provider(api_key=None)
transport = APITransport("gpt-4", provider)
args = SimpleNamespace(container=True, engine="podman")
with pytest.raises(ValueError, match="Missing API key"):
transport.ensure_model_exists(args)
def test_api_transport_overrides_provider_api_key(monkeypatch):
provider = make_provider()
transport = APITransport("gpt-4o-mini", provider)
recorded: dict[str, object] = {}
def fake_chat(args, operational_args=None, provider=None):
recorded["provider"] = provider
monkeypatch.setattr(api_module, "chat", fake_chat)
args = SimpleNamespace(container=True, engine="podman", url=None, model=None, api="none", api_key="cli-secret")
transport.run(args, [])
assert provider.api_key == "cli-secret"
assert recorded["provider"] is provider
def test_api_transport_ensure_exists_raises_if_model_missing(monkeypatch):
provider = make_provider()
transport = APITransport("gpt-4", provider)
monkeypatch.setattr(provider, "list_models", lambda: ["gpt-3.5"])
args = SimpleNamespace(container=True, engine="podman")
with pytest.raises(ValueError):
transport.ensure_model_exists(args)

View File

@@ -0,0 +1,33 @@
import io
import urllib.error
import pytest
import ramalama.chat_providers.base as base_module
from ramalama.chat_providers.base import ChatProviderError
from ramalama.chat_providers.openai import OpenAIResponsesChatProvider
def test_list_models_reports_auth_error(monkeypatch):
provider = OpenAIResponsesChatProvider("https://api.openai.com/v1", api_key="bad")
error_body = b'{"error":{"message":"Invalid API key"}}'
http_error = urllib.error.HTTPError(
provider.build_url("/models"),
401,
"Unauthorized",
{},
io.BytesIO(error_body),
)
def fake_urlopen(request):
raise http_error
monkeypatch.setattr(base_module.urllib_request, "urlopen", fake_urlopen)
with pytest.raises(ChatProviderError) as excinfo:
provider.list_models()
message = str(excinfo.value)
assert "Could not authenticate with openai." in message
assert "missing or invalid" in message
assert "Invalid API key" in message

View File

@@ -51,6 +51,7 @@ parser = get_parser()
special_cases = {
"api_key": "api-key",
"max_tokens": "max-tokens",
}
@@ -115,6 +116,10 @@ def test_default_endpoint(chatargs):
ChatSubArgs,
prefix=st.sampled_from(['> ', '🦙 > ', '🦭 > ', '🐋 > ']),
url=st.sampled_from(['https://test.com', 'test.com']),
temp=st.one_of(
st.none(),
st.floats(min_value=0, allow_nan=False, allow_infinity=False).map(lambda v: 0.0 if v == 0 else v),
),
)
)
def test_chat_endpoint(chatargs):

View File

@@ -4,12 +4,21 @@ from unittest.mock import mock_open, patch
import pytest
from ramalama.chat_utils import ImageURLPart
from ramalama.file_loaders.file_manager import ImageFileManager, OpanAIChatAPIMessageBuilder, TextFileManager
from ramalama.file_loaders.file_types.base import BaseFileLoader
from ramalama.file_loaders.file_types.image import BasicImageFileLoader
from ramalama.file_loaders.file_types.txt import TXTFileLoader
def _text_content(message):
return message.text or ""
def _image_parts(message):
return [attachment for attachment in message.attachments if isinstance(attachment, ImageURLPart)]
class TestBaseFileLoader:
"""Test the abstract base class for file upload handlers."""
@@ -269,9 +278,10 @@ class TestOpanAIChatAPIMessageBuilder:
messages = builder.load(tmp_file.name)
assert len(messages) == 1
assert messages[0]["role"] == "system"
assert "Test content" in messages[0]["content"]
assert f"<!--start_document {tmp_file.name}-->" in messages[0]["content"]
assert messages[0].role == "user"
content = _text_content(messages[0])
assert "Test content" in content
assert f"<!--start_document {tmp_file.name}-->" in content
def test_builder_load_image_files_only(self):
"""Test loading only image files."""
@@ -283,12 +293,10 @@ class TestOpanAIChatAPIMessageBuilder:
messages = builder.load(tmp_file.name)
assert len(messages) == 1
assert messages[0]["role"] == "system"
assert isinstance(messages[0]["content"], list)
assert len(messages[0]["content"]) == 1
assert 'image_url' in messages[0]["content"][0]
assert 'url' in messages[0]["content"][0]["image_url"]
assert "data:image/" in messages[0]["content"][0]["image_url"]["url"]
assert messages[0].role == "user"
image_parts = _image_parts(messages[0])
assert len(image_parts) == 1
assert "data:image/" in image_parts[0].url
def test_builder_load_mixed_files(self):
"""Test loading mixed text and image files."""
@@ -306,12 +314,11 @@ class TestOpanAIChatAPIMessageBuilder:
assert len(messages) == 2
# First message should be text
assert messages[0]["role"] == "system"
assert "Text content" in messages[0]["content"]
assert messages[0].role == "user"
assert "Text content" in _text_content(messages[0])
# Second message should be image
assert messages[1]["role"] == "system"
assert isinstance(messages[1]["content"], list)
assert len(messages[1]["content"]) == 1
assert messages[1].role == "user"
assert len(_image_parts(messages[1])) == 1
@pytest.mark.filterwarnings("ignore:.*Unsupported file types detected!.*")
def test_builder_load_no_supported_files(self):
@@ -405,7 +412,7 @@ class TestFileUploadIntegration:
messages = builder.load(tmp_dir)
assert len(messages) == 1
content = messages[0]["content"]
content = _text_content(messages[0])
for file_content in files_content.values():
assert file_content in content

View File

@@ -5,6 +5,11 @@ from unittest.mock import MagicMock, patch
import pytest
from ramalama.chat import RamaLamaShell, chat
from ramalama.chat_utils import ImageURLPart
def _text_content(message):
return message.text or ""
class TestFileUploadChatIntegration:
@@ -31,10 +36,11 @@ class TestFileUploadChatIntegration:
# Check that the system message was added to conversation history
assert len(shell.conversation_history) == 1
system_message = shell.conversation_history[0]
assert system_message["role"] == "system"
assert "This is test content for chat input" in system_message["content"]
assert f"<!--start_document {tmp_file.name}-->" in system_message["content"]
message = shell.conversation_history[0]
assert message.role == "user"
content = message.text or ""
assert "This is test content for chat input" in content
assert f"<!--start_document {tmp_file.name}-->" in content
@patch('urllib.request.urlopen')
def test_chat_with_file_input_directory(self, mock_urlopen):
@@ -62,13 +68,14 @@ class TestFileUploadChatIntegration:
# Check that the system message was added to conversation history
assert len(shell.conversation_history) == 1
system_message = shell.conversation_history[0]
assert system_message["role"] == "system"
assert "Text file content" in system_message["content"]
assert "# Markdown Content" in system_message["content"]
assert "test.txt" in system_message["content"]
assert "readme.md" in system_message["content"]
assert "<!--start_document" in system_message["content"]
message = shell.conversation_history[0]
assert message.role == "user"
content = message.text or ""
assert "Text file content" in content
assert "# Markdown Content" in content
assert "test.txt" in content
assert "readme.md" in content
assert "<!--start_document" in content
@pytest.mark.filterwarnings("ignore:.*Unsupported file types detected!.*")
@patch('urllib.request.urlopen')
@@ -125,11 +132,11 @@ class TestFileUploadChatIntegration:
# Check that the system message was added to conversation history
assert len(shell.conversation_history) == 1
system_message = shell.conversation_history[0]
assert system_message["role"] == "system"
assert f"<!--start_document {tmp_file.name}-->" in system_message["content"]
# Empty file should still have the delimiter but no content
assert system_message["content"].endswith(f"\n<!--start_document {tmp_file.name}-->\n")
message = shell.conversation_history[0]
assert message.role == "user"
text = _text_content(message)
assert f"<!--start_document {tmp_file.name}-->" in text
assert text.endswith(f"\n<!--start_document {tmp_file.name}-->\n")
@patch('urllib.request.urlopen')
def test_chat_with_file_input_unicode_content(self, mock_urlopen):
@@ -153,10 +160,11 @@ class TestFileUploadChatIntegration:
# Check that the system message was added to conversation history
assert len(shell.conversation_history) == 1
system_message = shell.conversation_history[0]
assert system_message["role"] == "system"
assert unicode_content in system_message["content"]
assert f"<!--start_document {tmp_file.name}-->" in system_message["content"]
message = shell.conversation_history[0]
assert message.role == "user"
text = message.text or ""
assert unicode_content in text
assert f"<!--start_document {tmp_file.name}-->" in text
@patch('urllib.request.urlopen')
def test_chat_with_file_input_mixed_content_types(self, mock_urlopen):
@@ -188,15 +196,16 @@ class TestFileUploadChatIntegration:
# Check that the system message was added to conversation history
assert len(shell.conversation_history) == 1
system_message = shell.conversation_history[0]
assert system_message["role"] == "system"
assert "English content" in system_message["content"]
assert '{"key": "value", "number": 42}' in system_message["content"]
assert "setting: enabled" in system_message["content"]
assert "values:" in system_message["content"]
assert "english.txt" in system_message["content"]
assert "data.json" in system_message["content"]
assert "config.yaml" in system_message["content"]
message = shell.conversation_history[0]
assert message.role == "user"
text = _text_content(message)
assert "English content" in text
assert '{"key": "value", "number": 42}' in text
assert "setting: enabled" in text
assert "values:" in text
assert "english.txt" in text
assert "data.json" in text
assert "config.yaml" in text
@patch('urllib.request.urlopen')
def test_chat_with_file_input_no_input_specified(self, mock_urlopen):
@@ -237,10 +246,11 @@ class TestFileUploadChatIntegration:
# Check that the system message was added to conversation history
assert len(shell.conversation_history) == 1
system_message = shell.conversation_history[0]
assert system_message["role"] == "system"
assert "File content" in system_message["content"]
assert f"<!--start_document {tmp_file.name}-->" in system_message["content"]
message = shell.conversation_history[0]
assert message.role == "user"
text = _text_content(message)
assert "File content" in text
assert f"<!--start_document {tmp_file.name}-->" in text
def test_chat_function_with_rag_and_dryrun(self):
"""Test that chat function works correctly with rag and dryrun."""
@@ -290,14 +300,13 @@ class TestImageUploadChatIntegration:
# Check that the system message was added to conversation history
assert len(shell.conversation_history) == 1
system_message = shell.conversation_history[0]
assert system_message["role"] == "system"
assert isinstance(system_message["content"], list)
assert len(system_message["content"]) == 1
assert 'image_url' in system_message["content"][0]
assert 'url' in system_message["content"][0]["image_url"]
assert "data:image/" in system_message["content"][0]["image_url"]["url"]
assert "base64," in system_message["content"][0]["image_url"]["url"]
message = shell.conversation_history[0]
assert message.role == "user"
assert len(message.attachments) == 1
part = message.attachments[0]
assert isinstance(part, ImageURLPart)
assert part.url.startswith("data:image/")
assert "base64," in part.url
@patch('urllib.request.urlopen')
def test_chat_with_image_input_directory(self, mock_urlopen):
@@ -325,14 +334,13 @@ class TestImageUploadChatIntegration:
# Check that the system message was added to conversation history
assert len(shell.conversation_history) == 1
system_message = shell.conversation_history[0]
assert system_message["role"] == "system"
assert isinstance(system_message["content"], list)
assert len(system_message["content"]) == 2
assert all('image_url' in item for item in system_message["content"])
assert all('url' in item["image_url"] for item in system_message["content"])
assert all("data:image/" in item["image_url"]["url"] for item in system_message["content"])
assert all("base64," in item["image_url"]["url"] for item in system_message["content"])
message = shell.conversation_history[0]
assert message.role == "user"
assert len(message.attachments) == 2
for part in message.attachments:
assert isinstance(part, ImageURLPart)
assert "data:image/" in part.url
assert "base64," in part.url
@patch('urllib.request.urlopen')
def test_chat_with_image_input_mixed_file_types(self, mock_urlopen):
@@ -359,30 +367,22 @@ class TestImageUploadChatIntegration:
shell = RamaLamaShell(mock_args)
# Check that two system messages were added to conversation history
system_messages = [msg for msg in shell.conversation_history if msg["role"] == "system"]
assert len(system_messages) == 2
user_messages = [msg for msg in shell.conversation_history if msg.role == "user"]
assert len(user_messages) == 2
# Determine which message is text and which is image
if isinstance(system_messages[0]["content"], str):
text_msg = system_messages[0]
image_msg = system_messages[1]
if user_messages[0].attachments:
image_msg = user_messages[0]
text_msg = user_messages[1]
else:
text_msg = system_messages[1]
image_msg = system_messages[0]
text_msg = user_messages[0]
image_msg = user_messages[1]
# Assert text message content
assert "Text content" in text_msg["content"]
assert "readme.txt" in text_msg["content"]
text = _text_content(text_msg)
assert "Text content" in text
assert "readme.txt" in text
# Assert image message content
assert isinstance(image_msg["content"], list)
assert any(
isinstance(item, dict)
and "image_url" in item
and "url" in item["image_url"]
and "data:image/" in item["image_url"]["url"]
for item in image_msg["content"]
)
assert any(isinstance(part, ImageURLPart) for part in image_msg.attachments)
@pytest.mark.filterwarnings("ignore:.*Unsupported file types detected!.*")
@patch('urllib.request.urlopen')
@@ -434,11 +434,10 @@ class TestImageUploadChatIntegration:
# Check that the system message was added to conversation history
assert len(shell.conversation_history) == 1
system_message = shell.conversation_history[0]
assert system_message["role"] == "system"
assert isinstance(system_message["content"], list)
assert len(system_message["content"]) == 2
assert all('image_url' in item for item in system_message["content"])
assert all('url' in item["image_url"] for item in system_message["content"])
assert all("data:image/" in item["image_url"]["url"] for item in system_message["content"])
assert all("base64," in item["image_url"]["url"] for item in system_message["content"])
message = shell.conversation_history[0]
assert message.role == "user"
assert len(message.attachments) == 2
for part in message.attachments:
assert isinstance(part, ImageURLPart)
assert "data:image/" in part.url
assert "base64," in part.url

View File

@@ -4,9 +4,18 @@ from pathlib import Path
import pytest
from ramalama.chat_utils import ImageURLPart
from ramalama.file_loaders.file_manager import OpanAIChatAPIMessageBuilder
def _text_content(message):
return message.text or ""
def _image_parts(message):
return [attachment for attachment in message.attachments if isinstance(attachment, ImageURLPart)]
class TestFileUploadWithDataFiles:
"""Test file upload functionality using sample data files."""
@@ -24,10 +33,11 @@ class TestFileUploadWithDataFiles:
messages = builder.load(str(txt_file))
assert len(messages) == 1
assert "This is a sample text file" in messages[0]["content"]
assert "TXTFileUpload class" in messages[0]["content"]
assert "Special characters like: !@#$%^&*()" in messages[0]["content"]
assert f"<!--start_document {txt_file}-->" in messages[0]["content"]
content = _text_content(messages[0])
assert "This is a sample text file" in content
assert "TXTFileUpload class" in content
assert "Special characters like: !@#$%^&*()" in content
assert f"<!--start_document {txt_file}-->" in content
def test_load_single_markdown_file(self, data_dir):
"""Test loading a single markdown file from the data directory."""
@@ -37,11 +47,12 @@ class TestFileUploadWithDataFiles:
messages = builder.load(str(md_file))
assert len(messages) == 1
assert "# Sample Markdown File" in messages[0]["content"]
assert "**Bold text** and *italic text*" in messages[0]["content"]
assert "```python" in messages[0]["content"]
assert "def hello_world():" in messages[0]["content"]
assert f"<!--start_document {md_file}-->" in messages[0]["content"]
content = _text_content(messages[0])
assert "# Sample Markdown File" in content
assert "**Bold text** and *italic text*" in content
assert "```python" in content
assert "def hello_world():" in content
assert f"<!--start_document {md_file}-->" in content
def test_load_single_json_file(self, data_dir):
"""Test loading a single JSON file from the data directory."""
@@ -51,11 +62,12 @@ class TestFileUploadWithDataFiles:
messages = builder.load(str(json_file))
assert len(messages) == 1
assert '"name": "test_data"' in messages[0]["content"]
assert '"version": "1.0.0"' in messages[0]["content"]
assert '"text_processing"' in messages[0]["content"]
assert '"supported_formats"' in messages[0]["content"]
assert f"<!--start_document {json_file}-->" in messages[0]["content"]
content = _text_content(messages[0])
assert '"name": "test_data"' in content
assert '"version": "1.0.0"' in content
assert '"text_processing"' in content
assert '"supported_formats"' in content
assert f"<!--start_document {json_file}-->" in content
def test_load_single_yaml_file(self, data_dir):
"""Test loading a single YAML file from the data directory."""
@@ -65,12 +77,13 @@ class TestFileUploadWithDataFiles:
messages = builder.load(str(yaml_file))
assert len(messages) == 1
assert "name: test_config" in messages[0]["content"]
assert "version: 1.0.0" in messages[0]["content"]
assert "- text_processing" in messages[0]["content"]
assert "- yaml_support" in messages[0]["content"]
assert "deep:" in messages[0]["content"]
assert f"<!--start_document {yaml_file}-->" in messages[0]["content"]
content = _text_content(messages[0])
assert "name: test_config" in content
assert "version: 1.0.0" in content
assert "- text_processing" in content
assert "- yaml_support" in content
assert "deep:" in content
assert f"<!--start_document {yaml_file}-->" in content
def test_load_single_csv_file(self, data_dir):
"""Test loading a single CSV file from the data directory."""
@@ -80,11 +93,12 @@ class TestFileUploadWithDataFiles:
messages = builder.load(str(csv_file))
assert len(messages) == 1
assert "name,age,city,occupation" in messages[0]["content"]
assert "John Doe,30,New York,Engineer" in messages[0]["content"]
assert "Jane Smith,25,San Francisco,Designer" in messages[0]["content"]
assert "Bob Johnson,35,Chicago,Manager" in messages[0]["content"]
assert f"<!--start_document {csv_file}-->" in messages[0]["content"]
content = _text_content(messages[0])
assert "name,age,city,occupation" in content
assert "John Doe,30,New York,Engineer" in content
assert "Jane Smith,25,San Francisco,Designer" in content
assert "Bob Johnson,35,Chicago,Manager" in content
assert f"<!--start_document {csv_file}-->" in content
def test_load_single_toml_file(self, data_dir):
"""Test loading a single TOML file from the data directory."""
@@ -94,12 +108,13 @@ class TestFileUploadWithDataFiles:
messages = builder.load(str(toml_file))
assert len(messages) == 1
assert 'name = "test_config"' in messages[0]["content"]
assert 'version = "1.0.0"' in messages[0]["content"]
assert 'text_processing = true' in messages[0]["content"]
assert 'toml_support = true' in messages[0]["content"]
assert 'with_deep_nesting = true' in messages[0]["content"]
assert f"<!--start_document {toml_file}-->" in messages[0]["content"]
content = _text_content(messages[0])
assert 'name = "test_config"' in content
assert 'version = "1.0.0"' in content
assert 'text_processing = true' in content
assert 'toml_support = true' in content
assert 'with_deep_nesting = true' in content
assert f"<!--start_document {toml_file}-->" in content
def test_load_single_shell_script(self, data_dir):
"""Test loading a single shell script from the data directory."""
@@ -109,12 +124,13 @@ class TestFileUploadWithDataFiles:
messages = builder.load(str(sh_file))
assert len(messages) == 1
assert "#!/bin/bash" in messages[0]["content"]
assert "Hello, World! This is a test script." in messages[0]["content"]
assert "test_function()" in messages[0]["content"]
assert "for i in {1..3}" in messages[0]["content"]
assert "Script completed successfully!" in messages[0]["content"]
assert f"<!--start_document {sh_file}-->" in messages[0]["content"]
content = _text_content(messages[0])
assert "#!/bin/bash" in content
assert "Hello, World! This is a test script." in content
assert "test_function()" in content
assert "for i in {1..3}" in content
assert "Script completed successfully!" in content
assert f"<!--start_document {sh_file}-->" in content
def test_load_entire_data_directory(self, data_dir):
"""Test loading all files from the data directory."""
@@ -122,7 +138,7 @@ class TestFileUploadWithDataFiles:
messages = builder.load(str(data_dir))
assert len(messages) == 1
content = messages[0]["content"]
content = _text_content(messages[0])
assert "This is a sample text file" in content # sample.txt
assert "# Sample Markdown File" in content # sample.md
assert '"name": "test_data"' in content # sample.json
@@ -151,7 +167,7 @@ class TestFileUploadWithDataFiles:
messages = builder.load(str(txt_file))
assert len(messages) == 1
content = messages[0]["content"]
content = _text_content(messages[0])
content_start = content.find('\n', content.find('<!--start_document')) + 1
extracted_content = content[content_start:]
@@ -172,7 +188,7 @@ class TestFileUploadWithDataFiles:
messages = builder.load(tmp_dir)
assert len(messages) == 1
content = messages[0]["content"]
content = _text_content(messages[0])
assert "This is a sample text file" in content # sample.txt
assert "# Sample Markdown File" in content # sample.md
assert '"name": "test_data"' in content # sample.json
@@ -200,7 +216,7 @@ class TestFileUploadWithDataFiles:
messages = builder.load(tmp_dir)
assert len(messages) == 1
content = messages[0]["content"]
content = _text_content(messages[0])
assert "This is a sample text file" in content
assert "This is an unsupported file type" not in content
assert "sample.txt" in content
@@ -226,12 +242,10 @@ class TestImageUploadWithDataFiles:
messages = builder.load(tmp_file.name)
assert len(messages) == 1
assert isinstance(messages[0]["content"], list)
assert len(messages[0]["content"]) == 1
assert 'image_url' in messages[0]["content"][0]
assert 'url' in messages[0]["content"][0]["image_url"]
assert "data:image/" in messages[0]["content"][0]["image_url"]["url"]
assert "base64," in messages[0]["content"][0]["image_url"]["url"]
image_parts = _image_parts(messages[0])
assert len(image_parts) == 1
assert "data:image/" in image_parts[0].url
assert "base64," in image_parts[0].url
def test_load_multiple_image_files(self, data_dir):
"""Test loading multiple image files."""
@@ -252,12 +266,10 @@ class TestImageUploadWithDataFiles:
messages = builder.load(tmp_dir)
assert len(messages) == 1
assert isinstance(messages[0]["content"], list)
assert len(messages[0]["content"]) == 3
assert all('image_url' in item for item in messages[0]["content"])
assert all('url' in item["image_url"] for item in messages[0]["content"])
assert all("data:image/" in item["image_url"]["url"] for item in messages[0]["content"])
assert all("base64," in item["image_url"]["url"] for item in messages[0]["content"])
image_parts = _image_parts(messages[0])
assert len(image_parts) == 3
assert all("data:image/" in part.url for part in image_parts)
assert all("base64," in part.url for part in image_parts)
def test_image_file_content_integrity(self, data_dir):
"""Test that image file content is preserved exactly."""
@@ -270,11 +282,11 @@ class TestImageUploadWithDataFiles:
messages = builder.load(tmp_file.name)
assert len(messages) == 1
assert isinstance(messages[0]["content"], list)
assert len(messages[0]["content"]) == 1
image_parts = _image_parts(messages[0])
assert len(image_parts) == 1
# Extract base64 data from result
url = messages[0]["content"][0]["image_url"]["url"]
url = image_parts[0].url
base64_data = url.split("base64,")[1]
import base64
@@ -304,12 +316,10 @@ class TestImageUploadWithDataFiles:
messages = builder.load(tmp_dir)
assert len(messages) == 1
assert isinstance(messages[0]["content"], list)
assert len(messages[0]["content"]) == 8
assert all('image_url' in item for item in messages[0]["content"])
assert all('url' in item["image_url"] for item in messages[0]["content"])
assert all("data:image/" in item["image_url"]["url"] for item in messages[0]["content"])
assert all("base64," in item["image_url"]["url"] for item in messages[0]["content"])
image_parts = _image_parts(messages[0])
assert len(image_parts) == 8
assert all("data:image/" in part.url for part in image_parts)
assert all("base64," in part.url for part in image_parts)
@pytest.mark.filterwarnings("ignore:.*Unsupported file types detected!.*")
def test_image_unsupported_file_handling(self, data_dir):
@@ -327,12 +337,10 @@ class TestImageUploadWithDataFiles:
messages = builder.load(tmp_dir)
assert len(messages) == 1
assert isinstance(messages[0]["content"], list)
assert len(messages[0]["content"]) == 1
assert 'image_url' in messages[0]["content"][0]
assert 'url' in messages[0]["content"][0]["image_url"]
assert "data:image/" in messages[0]["content"][0]["image_url"]["url"]
assert "base64," in messages[0]["content"][0]["image_url"]["url"]
image_parts = _image_parts(messages[0])
assert len(image_parts) == 1
assert "data:image/" in image_parts[0].url
assert "base64," in image_parts[0].url
def test_image_case_insensitive_extensions(self, data_dir):
"""Test that image file extensions are handled case-insensitively."""
@@ -357,9 +365,7 @@ class TestImageUploadWithDataFiles:
messages = builder.load(tmp_dir)
assert len(messages) == 1
assert isinstance(messages[0]["content"], list)
assert len(messages[0]["content"]) == 8
assert all('image_url' in item for item in messages[0]["content"])
assert all('url' in item["image_url"] for item in messages[0]["content"])
assert all("data:image/" in item["image_url"]["url"] for item in messages[0]["content"])
assert all("base64," in item["image_url"]["url"] for item in messages[0]["content"])
image_parts = _image_parts(messages[0])
assert len(image_parts) == 8
assert all("data:image/" in part.url for part in image_parts)
assert all("base64," in part.url for part in image_parts)

View File

@@ -3,6 +3,9 @@ from typing import Union
import pytest
import ramalama.transports.transport_factory as transport_factory_module
from ramalama.chat_providers.openai import OpenAIResponsesChatProvider
from ramalama.transports.api import APITransport
from ramalama.transports.huggingface import Huggingface
from ramalama.transports.modelscope import ModelScope
from ramalama.transports.oci import OCI
@@ -35,6 +38,7 @@ hf_granite_blob = "https://huggingface.co/ibm-granite/granite-3b-code-base-2k-GG
"input,expected,error",
[
(Input("", "", ""), None, KeyError),
(Input("openai://gpt-4o-mini", "", ""), APITransport, None),
(Input("huggingface://granite-code", "", ""), Huggingface, None),
(Input("hf://granite-code", "", ""), Huggingface, None),
(Input("hf.co/granite-code", "", ""), Huggingface, None),
@@ -113,6 +117,7 @@ def test_validate_oci_model_input(input: Input, error):
@pytest.mark.parametrize(
"input,expected",
[
(Input("openai://gpt-4o-mini", "", ""), "gpt-4o-mini"),
(Input("huggingface://granite-code", "", ""), "granite-code"),
(
Input("huggingface://ibm-granite/granite-3b-code-base-2k-GGUF/granite-code", "", ""),
@@ -164,3 +169,21 @@ def test_prune_model_input(input: Input, expected: str):
args = ARGS(input.Engine)
pruned_model_input = TransportFactory(input.Model, args, input.Transport).prune_model_input()
assert pruned_model_input == expected
def test_transport_factory_passes_scheme_to_get_chat_provider(monkeypatch):
args = ARGS()
provider = OpenAIResponsesChatProvider("https://api.openai.com/v1")
captured: dict[str, str] = {}
def fake_get_chat_provider(scheme: str):
captured["scheme"] = scheme
return provider
monkeypatch.setattr(transport_factory_module, "get_chat_provider", fake_get_chat_provider)
transport = TransportFactory("openai://gpt-4o-mini", args).create()
assert captured["scheme"] == "openai"
assert isinstance(transport, APITransport)
assert transport.provider is provider