adds support for hosted chat providers
Signed-off-by: Ian Eaves <ian.k.eaves@gmail.com>
2
Makefile
@@ -118,7 +118,7 @@ endif
|
||||
.PHONY: lint
|
||||
lint:
|
||||
ifneq (,$(wildcard /usr/bin/python3))
|
||||
/usr/bin/python3 -m compileall -q -x '\.venv' .
|
||||
${PYTHON} -m compileall -q -x '\.venv' .
|
||||
endif
|
||||
! grep -ri $(EXCLUDE_OPTS) "#\!/usr/bin/python3" .
|
||||
flake8 $(FLAKE8_ARGS) $(PROJECT_DIR) $(PYTHON_SCRIPTS)
|
||||
|
||||
@@ -29,6 +29,9 @@ Show this help message and exit
|
||||
#### **--list**
|
||||
List the available models at an endpoint
|
||||
|
||||
#### **--max-tokens**=*integer*
|
||||
Maximum number of tokens to generate. Set to 0 for unlimited output (default: 0).
|
||||
|
||||
#### **--mcp**=SERVER_URL
|
||||
MCP (Model Context Protocol) servers to use for enhanced tool calling capabilities.
|
||||
Can be specified multiple times to connect to multiple MCP servers.
|
||||
@@ -49,6 +52,10 @@ When enabled, ramalama will periodically condense older messages into a summary,
|
||||
keeping only recent messages and the summary. This prevents the context from growing
|
||||
indefinitely during long chat sessions. Set to 0 to disable (default: 4).
|
||||
|
||||
#### **--temp**=*float*
|
||||
Temperature of the response from the AI Model.
|
||||
Lower numbers are more deterministic, higher numbers are more creative.
|
||||
|
||||
#### **--url**=URL
|
||||
The host to send requests to (default: http://127.0.0.1:8080)
|
||||
|
||||
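These flags map directly onto the normalized request options introduced later in this change (`ChatRequestOptions` in `ramalama/chat_providers/base.py`). A minimal sketch of that mapping, using a trimmed copy of the dataclass and arbitrary example values:

```python
# Sketch only: how --temp / --max-tokens style values become request options.
# ChatRequestOptions is trimmed from ramalama/chat_providers/base.py in this change.
from dataclasses import dataclass


@dataclass(slots=True)
class ChatRequestOptions:
    model: str | None = None
    temperature: float | None = None
    max_tokens: int | None = None
    stream: bool = True


def build_options(model: str | None, temp: float | None, max_tokens: int | None) -> ChatRequestOptions:
    # A --max-tokens of 0 means "unlimited", so it is dropped from the request entirely.
    if max_tokens is not None and max_tokens <= 0:
        max_tokens = None
    return ChatRequestOptions(model=model, temperature=temp, max_tokens=max_tokens)


print(build_options("granite", temp=0.2, max_tokens=0))
# ChatRequestOptions(model='granite', temperature=0.2, max_tokens=None, stream=True)
```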
|
||||
@@ -17,13 +17,17 @@ ramalama\-run - run specified AI Model as a chatbot
|
||||
| rlcr | rlcr:// | [`ramalama.com`](https://registry.ramalama.com) |
|
||||
| OCI Container Registries | oci:// | [`opencontainers.org`](https://opencontainers.org)|
|
||||
|||Examples: [`quay.io`](https://quay.io), [`Docker Hub`](https://docker.io), [`Artifactory`](https://artifactory.com)|
|
||||
| Hosted API Providers | openai:// | [`api.openai.com`](https://api.openai.com)|
|
||||
|
||||
RamaLama defaults to the Ollama registry transport. This default can be overridden in the `ramalama.conf` file or via the `RAMALAMA_TRANSPORT`
environment variable. `export RAMALAMA_TRANSPORT=huggingface` changes RamaLama to use the Hugging Face transport.
|
||||
|
||||
Modify individual model transports by specifying the `huggingface://`, `oci://`, `ollama://`, `https://`, `http://`, `file://` prefix to the model.
|
||||
Modify individual model transports by specifying the `huggingface://`, `oci://`, `ollama://`, `https://`, `http://`, `file://`, or hosted API
prefix (for example `openai://`) to the model.
|
||||
|
||||
URL support means if a model is on a web site or even on your local system, you can run it directly.
|
||||
Hosted API transports connect directly to the remote provider and bypass the local container runtime. In this mode, flags that tune local
|
||||
containers (for example `--image`, GPU settings, or `--network`) do not apply, and the provider's own capabilities and security posture govern
|
||||
the execution. URL support means that if a model is on a website or even on your local system, you can run it directly.
|
||||
|
||||
## OPTIONS
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ RamaLama CLI defaults can be modified via ramalama.conf files. Default settings
|
||||
|
||||
### Test and run your models more securely
|
||||
|
||||
Because RamaLama defaults to running AI models inside of rootless containers using Podman on Docker. These containers isolate the AI models from information on the underlying host. With RamaLama containers, the AI model is mounted as a volume into the container in read/only mode. This results in the process running the model, llama.cpp or vLLM, being isolated from the host. In addition, since `ramalama run` uses the --network=none option, the container can not reach the network and leak any information out of the system. Finally, containers are run with --rm options which means that any content written during the running of the container is wiped out when the application exits.
|
||||
RamaLama defaults to running AI models inside rootless containers using Podman or Docker. These containers isolate the AI models from information on the underlying host. With RamaLama containers, the AI model is mounted as a volume into the container in read-only mode. This results in the process running the model, llama.cpp or vLLM, being isolated from the host. In addition, since `ramalama run` uses the --network=none option, the container cannot reach the network and leak any information out of the system. Finally, containers are run with the --rm option, which means that any content written during the running of the container is wiped out when the application exits. Hosted API transports such as `openai://` bypass the container runtime entirely and connect directly to the remote provider; those transports inherit the provider's network access and security guarantees instead of RamaLama's container sandbox.
|
||||
|
||||
### Here’s how RamaLama delivers a robust security footprint:
|
||||
|
||||
|
||||
@@ -205,12 +205,26 @@
|
||||
# The maximum delay between retry attempts in seconds.
|
||||
#
|
||||
#max_retry_delay = 30
|
||||
|
||||
|
||||
[ramalama.provider]
|
||||
# Provider-specific hosted API configuration. Set per-provider options in the
|
||||
# nested tables below.
|
||||
|
||||
|
||||
[ramalama.provider.openai]
|
||||
# Optional provider-specific API key used when contacting OpenAI-hosted
|
||||
# transports. If unset, RamaLama falls back to the RAMALAMA_API_KEY value
|
||||
# or environment variables expected by the provider.
|
||||
#
|
||||
#api_key = "sk-..."
|
||||
|
||||
|
||||
[ramalama.user]
|
||||
# Suppress the interactive prompt when running on macOS with a Podman VM
|
||||
# that doesn't support GPU acceleration (e.g., applehv provider).
|
||||
# When set to true, RamaLama will automatically proceed without GPU support
|
||||
# instead of asking for confirmation.
|
||||
# Can also be set via the `RAMALAMA_USER__NO_MISSING_GPU_PROMPT` environment variable.
|
||||
#
|
||||
|
||||
[ramalama.user]
|
||||
#no_missing_gpu_prompt = false
|
||||
|
||||
@@ -253,6 +253,21 @@ The maximum number of times to retry a failed download
|
||||
|
||||
The maximum delay between retry attempts in seconds
|
||||
|
||||
## RAMALAMA.PROVIDER TABLE
|
||||
The `ramalama.provider` table configures hosted API providers that RamaLama can connect to.
|
||||
|
||||
`[[ramalama.provider]]`
|
||||
|
||||
**openai**=""
|
||||
|
||||
Configuration settings for the openai hosted provider
|
||||
|
||||
`[[ramalama.provider.openai]]`
|
||||
|
||||
**api_key**=""
|
||||
|
||||
Provider-specific API key used when invoking OpenAI-hosted transports. Overrides `RAMALAMA_API_KEY` when set.
|
||||
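The override order described above is implemented by `get_provider_api_key` in `ramalama/chat_providers/api_providers.py` (shown later in this change); a short usage sketch:

```python
# Sketch: provider key resolution order.
# 1. ramalama.provider.<name>.api_key from ramalama.conf
# 2. the generic api_key setting / RAMALAMA_API_KEY value
from ramalama.chat_providers.api_providers import get_provider_api_key

key = get_provider_api_key("openai")   # None when neither source is configured
```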
|
||||
## RAMALAMA.USER TABLE
|
||||
The ramalama.user table contains user preference settings.
|
||||
|
||||
|
||||
152
docsite/package-lock.json
generated
@@ -27,6 +27,76 @@
|
||||
"node": ">=18.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@ai-sdk/gateway": {
|
||||
"version": "2.0.24",
|
||||
"resolved": "https://registry.npmjs.org/@ai-sdk/gateway/-/gateway-2.0.24.tgz",
|
||||
"integrity": "sha512-mflk80YF8hj8vrF9e1IHhovGKC1ubX+sY88pesSk3pUiXfH5VPO8dgzNnxjwsqsCZrnkHcztxS5cSl4TzSiEuA==",
|
||||
"license": "Apache-2.0",
|
||||
"dependencies": {
|
||||
"@ai-sdk/provider": "2.0.1",
|
||||
"@ai-sdk/provider-utils": "3.0.20",
|
||||
"@vercel/oidc": "3.0.5"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=18"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"zod": "^3.25.76 || ^4.1.8"
|
||||
}
|
||||
},
|
||||
"node_modules/@ai-sdk/provider": {
|
||||
"version": "2.0.1",
|
||||
"resolved": "https://registry.npmjs.org/@ai-sdk/provider/-/provider-2.0.1.tgz",
|
||||
"integrity": "sha512-KCUwswvsC5VsW2PWFqF8eJgSCu5Ysj7m1TxiHTVA6g7k360bk0RNQENT8KTMAYEs+8fWPD3Uu4dEmzGHc+jGng==",
|
||||
"license": "Apache-2.0",
|
||||
"dependencies": {
|
||||
"json-schema": "^0.4.0"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=18"
|
||||
}
|
||||
},
|
||||
"node_modules/@ai-sdk/provider-utils": {
|
||||
"version": "3.0.20",
|
||||
"resolved": "https://registry.npmjs.org/@ai-sdk/provider-utils/-/provider-utils-3.0.20.tgz",
|
||||
"integrity": "sha512-iXHVe0apM2zUEzauqJwqmpC37A5rihrStAih5Ks+JE32iTe4LZ58y17UGBjpQQTCRw9YxMeo2UFLxLpBluyvLQ==",
|
||||
"license": "Apache-2.0",
|
||||
"dependencies": {
|
||||
"@ai-sdk/provider": "2.0.1",
|
||||
"@standard-schema/spec": "^1.0.0",
|
||||
"eventsource-parser": "^3.0.6"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=18"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"zod": "^3.25.76 || ^4.1.8"
|
||||
}
|
||||
},
|
||||
"node_modules/@ai-sdk/react": {
|
||||
"version": "2.0.119",
|
||||
"resolved": "https://registry.npmjs.org/@ai-sdk/react/-/react-2.0.119.tgz",
|
||||
"integrity": "sha512-kl4CDAnKJ1z+Fc9cjwMQXLRqH5/gHhg8Jn9qW7sZ0LgL8VpiDmW+x+s8e588nE3eC88aL1OxOVyOE6lFYfWprw==",
|
||||
"license": "Apache-2.0",
|
||||
"dependencies": {
|
||||
"@ai-sdk/provider-utils": "3.0.20",
|
||||
"ai": "5.0.117",
|
||||
"swr": "^2.2.5",
|
||||
"throttleit": "2.1.0"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=18"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"react": "^18 || ~19.0.1 || ~19.1.2 || ^19.2.1",
|
||||
"zod": "^3.25.76 || ^4.1.8"
|
||||
},
|
||||
"peerDependenciesMeta": {
|
||||
"zod": {
|
||||
"optional": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/@algolia/abtesting": {
|
||||
"version": "1.13.0",
|
||||
"resolved": "https://registry.npmjs.org/@algolia/abtesting/-/abtesting-1.13.0.tgz",
|
||||
@@ -2832,9 +2902,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@csstools/postcss-normalize-display-values": {
|
||||
"version": "4.0.1",
|
||||
"resolved": "https://registry.npmjs.org/@csstools/postcss-normalize-display-values/-/postcss-normalize-display-values-4.0.1.tgz",
|
||||
"integrity": "sha512-TQUGBuRvxdc7TgNSTevYqrL8oItxiwPDixk20qCB5me/W8uF7BPbhRrAvFuhEoywQp/woRsUZ6SJ+sU5idZAIA==",
|
||||
"version": "4.0.0",
|
||||
"resolved": "https://registry.npmjs.org/@csstools/postcss-normalize-display-values/-/postcss-normalize-display-values-4.0.0.tgz",
|
||||
"integrity": "sha512-HlEoG0IDRoHXzXnkV4in47dzsxdsjdz6+j7MLjaACABX2NfvjFS6XVAnpaDyGesz9gK2SC7MbNwdCHusObKJ9Q==",
|
||||
"funding": [
|
||||
{
|
||||
"type": "github",
|
||||
@@ -5274,9 +5344,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@types/express-serve-static-core": {
|
||||
"version": "4.19.8",
|
||||
"resolved": "https://registry.npmjs.org/@types/express-serve-static-core/-/express-serve-static-core-4.19.8.tgz",
|
||||
"integrity": "sha512-02S5fmqeoKzVZCHPZid4b8JH2eM5HzQLZWN2FohQEy/0eXTq8VXZfSN6Pcr3F6N9R/vNrj7cpgbhjie6m/1tCA==",
|
||||
"version": "4.19.7",
|
||||
"resolved": "https://registry.npmjs.org/@types/express-serve-static-core/-/express-serve-static-core-4.19.7.tgz",
|
||||
"integrity": "sha512-FvPtiIf1LfhzsaIXhv/PHan/2FeQBbtBDtfX2QfvPxdUelMDEckK08SM6nqo1MIZY3RUlfA+HV8+hFUSio78qg==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@types/node": "*",
|
||||
@@ -5391,9 +5461,15 @@
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/@types/node": {
|
||||
<<<<<<< HEAD
|
||||
"version": "25.0.10",
|
||||
"resolved": "https://registry.npmjs.org/@types/node/-/node-25.0.10.tgz",
|
||||
"integrity": "sha512-zWW5KPngR/yvakJgGOmZ5vTBemDoSqF3AcV/LrO5u5wTWyEAVVh+IT39G4gtyAkh3CtTZs8aX/yRM82OfzHJRg==",
|
||||
=======
|
||||
"version": "25.0.3",
|
||||
"resolved": "https://registry.npmjs.org/@types/node/-/node-25.0.3.tgz",
|
||||
"integrity": "sha512-W609buLVRVmeW693xKfzHeIV6nJGGz98uCPfeXI1ELMLXVeKYZ9m15fAMSaUPBHYLGFsVRcMmSCksQOrZV9BYA==",
|
||||
>>>>>>> ca824d98 (adds support for hosted chat providers)
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"undici-types": "~7.16.0"
|
||||
@@ -5418,9 +5494,15 @@
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/@types/react": {
|
||||
<<<<<<< HEAD
|
||||
"version": "19.2.9",
|
||||
"resolved": "https://registry.npmjs.org/@types/react/-/react-19.2.9.tgz",
|
||||
"integrity": "sha512-Lpo8kgb/igvMIPeNV2rsYKTgaORYdO1XGVZ4Qz3akwOj0ySGYMPlQWa8BaLn0G63D1aSaAQ5ldR06wCpChQCjA==",
|
||||
=======
|
||||
"version": "19.2.7",
|
||||
"resolved": "https://registry.npmjs.org/@types/react/-/react-19.2.7.tgz",
|
||||
"integrity": "sha512-MWtvHrGZLFttgeEj28VXHxpmwYbor/ATPYbBfSFZEIRK0ecCFLl2Qo55z52Hss+UV9CRN7trSeq1zbgx7YDWWg==",
|
||||
>>>>>>> ca824d98 (adds support for hosted chat providers)
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"csstype": "^3.2.2"
|
||||
@@ -5825,6 +5907,27 @@
|
||||
"node": ">=8"
|
||||
}
|
||||
},
|
||||
<<<<<<< HEAD
|
||||
=======
|
||||
"node_modules/ai": {
|
||||
"version": "5.0.117",
|
||||
"resolved": "https://registry.npmjs.org/ai/-/ai-5.0.117.tgz",
|
||||
"integrity": "sha512-uE6HNkdSwxbeHGKP/YbvapwD8fMOpj87wyfT9Z00pbzOh2fpnw5acak/4kzU00SX2vtI9K0uuy+9Tf9ytw5RwA==",
|
||||
"license": "Apache-2.0",
|
||||
"dependencies": {
|
||||
"@ai-sdk/gateway": "2.0.24",
|
||||
"@ai-sdk/provider": "2.0.1",
|
||||
"@ai-sdk/provider-utils": "3.0.20",
|
||||
"@opentelemetry/api": "1.9.0"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=18"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"zod": "^3.25.76 || ^4.1.8"
|
||||
}
|
||||
},
|
||||
>>>>>>> ca824d98 (adds support for hosted chat providers)
|
||||
"node_modules/ajv": {
|
||||
"version": "8.17.1",
|
||||
"resolved": "https://registry.npmjs.org/ajv/-/ajv-8.17.1.tgz",
|
||||
@@ -6189,9 +6292,15 @@
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/baseline-browser-mapping": {
|
||||
<<<<<<< HEAD
|
||||
"version": "2.9.17",
|
||||
"resolved": "https://registry.npmjs.org/baseline-browser-mapping/-/baseline-browser-mapping-2.9.17.tgz",
|
||||
"integrity": "sha512-agD0MgJFUP/4nvjqzIB29zRPUuCF7Ge6mEv9s8dHrtYD7QWXRcx75rOADE/d5ah1NI+0vkDl0yorDd5U852IQQ==",
|
||||
=======
|
||||
"version": "2.9.11",
|
||||
"resolved": "https://registry.npmjs.org/baseline-browser-mapping/-/baseline-browser-mapping-2.9.11.tgz",
|
||||
"integrity": "sha512-Sg0xJUNDU1sJNGdfGWhVHX0kkZ+HWcvmVymJbj6NSgZZmW/8S9Y2HQ5euytnIgakgxN6papOAWiwDo1ctFDcoQ==",
|
||||
>>>>>>> ca824d98 (adds support for hosted chat providers)
|
||||
"license": "Apache-2.0",
|
||||
"bin": {
|
||||
"baseline-browser-mapping": "dist/cli.js"
|
||||
@@ -6522,9 +6631,15 @@
|
||||
}
|
||||
},
|
||||
"node_modules/caniuse-lite": {
|
||||
<<<<<<< HEAD
|
||||
"version": "1.0.30001765",
|
||||
"resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001765.tgz",
|
||||
"integrity": "sha512-LWcNtSyZrakjECqmpP4qdg0MMGdN368D7X8XvvAqOcqMv0RxnlqVKZl2V6/mBR68oYMxOZPLw/gO7DuisMHUvQ==",
|
||||
=======
|
||||
"version": "1.0.30001762",
|
||||
"resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001762.tgz",
|
||||
"integrity": "sha512-PxZwGNvH7Ak8WX5iXzoK1KPZttBXNPuaOvI2ZYU7NrlM+d9Ov+TUvlLOBNGzVXAntMSMMlJPd+jY6ovrVjSmUw==",
|
||||
>>>>>>> ca824d98 (adds support for hosted chat providers)
|
||||
"funding": [
|
||||
{
|
||||
"type": "opencollective",
|
||||
@@ -14784,9 +14899,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/postcss-preset-env": {
|
||||
"version": "10.6.1",
|
||||
"resolved": "https://registry.npmjs.org/postcss-preset-env/-/postcss-preset-env-10.6.1.tgz",
|
||||
"integrity": "sha512-yrk74d9EvY+W7+lO9Aj1QmjWY9q5NsKjK2V9drkOPZB/X6KZ0B3igKsHUYakb7oYVhnioWypQX3xGuePf89f3g==",
|
||||
"version": "10.6.0",
|
||||
"resolved": "https://registry.npmjs.org/postcss-preset-env/-/postcss-preset-env-10.6.0.tgz",
|
||||
"integrity": "sha512-+LzpUSLCGHUdlZ1YZP7lp7w1MjxInJRSG0uaLyk/V/BM17iU2B7xTO7I8x3uk0WQAcLLh/ffqKzOzfaBvG7Fdw==",
|
||||
"funding": [
|
||||
{
|
||||
"type": "github",
|
||||
@@ -14824,7 +14939,7 @@
|
||||
"@csstools/postcss-media-minmax": "^2.0.9",
|
||||
"@csstools/postcss-media-queries-aspect-ratio-number-values": "^3.0.5",
|
||||
"@csstools/postcss-nested-calc": "^4.0.0",
|
||||
"@csstools/postcss-normalize-display-values": "^4.0.1",
|
||||
"@csstools/postcss-normalize-display-values": "^4.0.0",
|
||||
"@csstools/postcss-oklab-function": "^4.0.12",
|
||||
"@csstools/postcss-position-area-property": "^1.0.0",
|
||||
"@csstools/postcss-progressive-custom-properties": "^4.2.1",
|
||||
@@ -16132,13 +16247,10 @@
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/sax": {
|
||||
"version": "1.4.4",
|
||||
"resolved": "https://registry.npmjs.org/sax/-/sax-1.4.4.tgz",
|
||||
"integrity": "sha512-1n3r/tGXO6b6VXMdFT54SHzT9ytu9yr7TaELowdYpMqY/Ao7EnlQGmAQ1+RatX7Tkkdm6hONI2owqNx2aZj5Sw==",
|
||||
"license": "BlueOak-1.0.0",
|
||||
"engines": {
|
||||
"node": ">=11.0.0"
|
||||
}
|
||||
"version": "1.4.3",
|
||||
"resolved": "https://registry.npmjs.org/sax/-/sax-1.4.3.tgz",
|
||||
"integrity": "sha512-yqYn1JhPczigF94DMS+shiDMjDowYO6y9+wB/4WgO0Y19jWYk0lQ4tuG5KI7kj4FTp1wxPj5IFfcrz/s1c3jjQ==",
|
||||
"license": "BlueOak-1.0.0"
|
||||
},
|
||||
"node_modules/scheduler": {
|
||||
"version": "0.27.0",
|
||||
@@ -18042,9 +18154,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/webpack-dev-server/node_modules/ws": {
|
||||
"version": "8.19.0",
|
||||
"resolved": "https://registry.npmjs.org/ws/-/ws-8.19.0.tgz",
|
||||
"integrity": "sha512-blAT2mjOEIi0ZzruJfIhb3nps74PRWTCz1IjglWEEpQl5XS/UNama6u2/rjFkDDouqr4L67ry+1aGIALViWjDg==",
|
||||
"version": "8.18.3",
|
||||
"resolved": "https://registry.npmjs.org/ws/-/ws-8.18.3.tgz",
|
||||
"integrity": "sha512-PEIGCY5tSlUt50cqyMXfCzX+oOPqN0vuGqWzbcJ2xvnkzkq46oOpz7dQaTDBdfICb4N14+GARUDw2XV2N4tvzg==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=10.0.0"
|
||||
|
||||
@@ -2,9 +2,10 @@
|
||||
|
||||
import sys
|
||||
|
||||
from ramalama import cli
|
||||
from ramalama.cli import HelpException, init_cli, print_version
|
||||
from ramalama.common import perror
|
||||
|
||||
assert sys.version_info >= (3, 10), "Python 3.10 or greater is required."
|
||||
|
||||
__all__ = ["perror", "init_cli", "print_version", "HelpException"]
|
||||
__all__ = ["cli", "perror", "init_cli", "print_version", "HelpException"]
|
||||
|
||||
@@ -82,6 +82,8 @@ class ChatSubArgsType(Protocol):
|
||||
rag: str | None
|
||||
api_key: str | None
|
||||
ARGS: List[str] | None
|
||||
max_tokens: int | None
|
||||
temp: float | None
|
||||
|
||||
|
||||
ChatSubArgs = protocol_to_dataclass(ChatSubArgsType)
|
||||
|
||||
321
ramalama/chat.py
@@ -5,7 +5,6 @@ import cmd
|
||||
import copy
|
||||
import itertools
|
||||
import json
|
||||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
@@ -13,13 +12,24 @@ import threading
|
||||
import time
|
||||
import urllib.error
|
||||
import urllib.request
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from datetime import timedelta
|
||||
|
||||
from ramalama.arg_types import ChatArgsType
|
||||
from ramalama.chat_providers import ChatProvider, ChatRequestOptions
|
||||
from ramalama.chat_providers.openai import OpenAICompletionsChatProvider
|
||||
from ramalama.chat_utils import (
|
||||
AssistantMessage,
|
||||
ChatMessageType,
|
||||
SystemMessage,
|
||||
ToolMessage,
|
||||
UserMessage,
|
||||
stream_response,
|
||||
)
|
||||
from ramalama.common import perror
|
||||
from ramalama.config import CONFIG
|
||||
from ramalama.console import EMOJI, should_colorize
|
||||
from ramalama.console import should_colorize
|
||||
from ramalama.engine import stop_container
|
||||
from ramalama.file_loaders.file_manager import OpanAIChatAPIMessageBuilder
|
||||
from ramalama.logger import logger
|
||||
@@ -31,69 +41,6 @@ from ramalama.proxy_support import setup_proxy_support
|
||||
setup_proxy_support()
|
||||
|
||||
|
||||
def res(response, color):
|
||||
color_default = ""
|
||||
color_yellow = ""
|
||||
if (color == "auto" and should_colorize()) or color == "always":
|
||||
color_default = "\033[0m"
|
||||
color_yellow = "\033[33m"
|
||||
|
||||
print("\r", end="")
|
||||
assistant_response = ""
|
||||
for line in response:
|
||||
line = line.decode("utf-8").strip()
|
||||
if line.startswith("data: {"):
|
||||
choice = ""
|
||||
|
||||
json_line = json.loads(line[len("data: ") :])
|
||||
if "choices" in json_line and json_line["choices"]:
|
||||
choice = json_line["choices"][0]["delta"]
|
||||
if "content" in choice:
|
||||
choice = choice["content"]
|
||||
else:
|
||||
continue
|
||||
|
||||
if choice:
|
||||
print(f"{color_yellow}{choice}{color_default}", end="", flush=True)
|
||||
assistant_response += choice
|
||||
|
||||
print("")
|
||||
return assistant_response
|
||||
|
||||
|
||||
def default_prefix():
|
||||
if not EMOJI:
|
||||
return "> "
|
||||
|
||||
if CONFIG.prefix:
|
||||
return CONFIG.prefix
|
||||
|
||||
engine = CONFIG.engine
|
||||
|
||||
if engine:
|
||||
if os.path.basename(engine) == "podman":
|
||||
return "🦭 > "
|
||||
|
||||
if os.path.basename(engine) == "docker":
|
||||
return "🐋 > "
|
||||
|
||||
return "🦙 > "
|
||||
|
||||
|
||||
def add_api_key(args, headers=None):
|
||||
# static analyzers suggest for dict, this is a safer way of setting
|
||||
# a default value, rather than using the parameter directly
|
||||
headers = headers or {}
|
||||
if getattr(args, "api_key", None):
|
||||
api_key_min = 20
|
||||
if len(args.api_key) < api_key_min:
|
||||
perror("Warning: Provided API key is invalid.")
|
||||
|
||||
headers["Authorization"] = f"Bearer {args.api_key}"
|
||||
|
||||
return headers
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatOperationalArgs:
|
||||
initial_connection: bool = False
|
||||
@@ -102,18 +49,70 @@ class ChatOperationalArgs:
|
||||
monitor: "ServerMonitor | None" = None
|
||||
|
||||
|
||||
class Spinner:
|
||||
def __init__(self, wait_time: float = 0.1):
|
||||
self._stop_event: threading.Event = threading.Event()
|
||||
self._thread: threading.Thread | None = None
|
||||
self.wait_time = wait_time
|
||||
|
||||
def __enter__(self):
|
||||
self.start()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.stop()
|
||||
return False
|
||||
|
||||
def start(self) -> "Spinner":
|
||||
if not sys.stdout.isatty():
|
||||
return self
|
||||
|
||||
if self._thread is not None:
|
||||
self.stop()
|
||||
|
||||
self._thread = threading.Thread(target=self._spinner_loop, daemon=True)
|
||||
self._thread.start()
|
||||
return self
|
||||
|
||||
def stop(self):
|
||||
if self._thread is None:
|
||||
return
|
||||
|
||||
self._stop_event.set()
|
||||
self._thread.join(timeout=0.2)
|
||||
perror("\r", end="", flush=True)
|
||||
self._thread = None
|
||||
self._stop_event = threading.Event()
|
||||
|
||||
def _spinner_loop(self):
|
||||
frames = ['⠋', '⠙', '⠹', '⠸', '⠼', '⠴', '⠦', '⠧', '⠇', '⠏']
|
||||
|
||||
for frame in itertools.cycle(frames):
|
||||
if self._stop_event.is_set():
|
||||
break
|
||||
perror(f"\r{frame}", end="", flush=True)
|
||||
self._stop_event.wait(self.wait_time)
|
||||
|
||||
|
||||
class RamaLamaShell(cmd.Cmd):
|
||||
def __init__(self, args: ChatArgsType, operational_args: ChatOperationalArgs | None = None):
|
||||
def __init__(
|
||||
self,
|
||||
args: ChatArgsType,
|
||||
operational_args: ChatOperationalArgs | None = None,
|
||||
provider: ChatProvider | None = None,
|
||||
):
|
||||
if operational_args is None:
|
||||
operational_args = ChatOperationalArgs()
|
||||
|
||||
super().__init__()
|
||||
self.conversation_history: list[dict] = []
|
||||
self.conversation_history: list[ChatMessageType] = []
|
||||
self.args = args
|
||||
self.operational_args = operational_args
|
||||
self.request_in_process = False
|
||||
self.prompt = args.prefix
|
||||
self.url = f"{args.url}/chat/completions"
|
||||
self.provider = provider or OpenAICompletionsChatProvider(args.url, getattr(args, "api_key", None))
|
||||
self.url = self.provider.build_url()
|
||||
|
||||
self.prep_rag_message()
|
||||
self.mcp_agent: LLMAgent | None = None
|
||||
self.initialize_mcp()
|
||||
@@ -131,7 +130,7 @@ class RamaLamaShell(cmd.Cmd):
|
||||
|
||||
def _summarize_conversation(self):
|
||||
"""Summarize the conversation history to prevent context growth."""
|
||||
if len(self.conversation_history) < 4:
|
||||
if len(self.conversation_history) < 10:
|
||||
# Need at least a few messages to summarize
|
||||
return
|
||||
|
||||
@@ -145,16 +144,14 @@ class RamaLamaShell(cmd.Cmd):
|
||||
return
|
||||
|
||||
# Create a summarization prompt
|
||||
conversation_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages_to_summarize])
|
||||
|
||||
summary_prompt = {
|
||||
"role": "user",
|
||||
"content": (
|
||||
f"Please provide a concise summary of the following conversation, "
|
||||
conversation_text = "\n".join([self._format_message_for_summary(msg) for msg in messages_to_summarize])
|
||||
summary_prompt = UserMessage(
|
||||
text=(
|
||||
"Please provide a concise summary of the following conversation, "
|
||||
f"preserving key information and context:\n\n{conversation_text}\n\n"
|
||||
f"Provide only the summary, without any preamble."
|
||||
),
|
||||
}
|
||||
"Provide only the summary, without any preamble."
|
||||
)
|
||||
)
|
||||
|
||||
# Make API call to get summary
|
||||
# Provide user feedback during summarization
|
||||
@@ -167,12 +164,12 @@ class RamaLamaShell(cmd.Cmd):
|
||||
summary = result['choices'][0]['message']['content']
|
||||
|
||||
# Rebuild conversation history with summary
|
||||
new_history = []
|
||||
new_history: list[ChatMessageType] = []
|
||||
if first_msg:
|
||||
new_history.append(first_msg)
|
||||
|
||||
# Add summary as a system message
|
||||
new_history.append({"role": "system", "content": f"Previous conversation summary: {summary}"})
|
||||
new_history.append(SystemMessage(text=f"Previous conversation summary: {summary}"))
|
||||
|
||||
# Add recent messages
|
||||
new_history.extend(recent_msgs)
|
||||
@@ -196,35 +193,42 @@ class RamaLamaShell(cmd.Cmd):
|
||||
self._summarize_conversation()
|
||||
self.message_count = 0 # Reset counter after summarization
|
||||
|
||||
def _make_api_request(self, messages, stream=True):
|
||||
"""Make an API request with the given messages.
|
||||
def _history_snapshot(self) -> list[dict[str, str]]:
|
||||
return [
|
||||
{"role": msg.role, "content": self._format_message_for_summary(msg)} for msg in self.conversation_history
|
||||
]
|
||||
|
||||
Args:
|
||||
messages: List of message dicts to send
|
||||
stream: Whether to stream the response
|
||||
def _format_message_for_summary(self, msg: ChatMessageType) -> str:
|
||||
content = msg.text or ""
|
||||
if isinstance(msg, AssistantMessage):
|
||||
if msg.tool_calls:
|
||||
content += f"\n[tool_calls: {', '.join(call.name for call in msg.tool_calls)}]"
|
||||
|
||||
Returns:
|
||||
urllib.request.Request object
|
||||
"""
|
||||
data = {
|
||||
"stream": stream,
|
||||
"messages": messages,
|
||||
}
|
||||
if getattr(self.args, "model", None):
|
||||
data["model"] = self.args.model
|
||||
if getattr(self.args, "temp", None):
|
||||
data["temperature"] = float(self.args.temp)
|
||||
if stream and getattr(self.args, "max_tokens", None):
|
||||
data["max_completion_tokens"] = self.args.max_tokens
|
||||
if isinstance(msg, ToolMessage):
|
||||
content = f"\n[tool_response: {msg.text}]"
|
||||
|
||||
headers = add_api_key(self.args)
|
||||
headers["Content-Type"] = "application/json"
|
||||
return f"{msg.role}: {content}".strip()
|
||||
|
||||
return urllib.request.Request(
|
||||
self.url,
|
||||
data=json.dumps(data).encode('utf-8'),
|
||||
headers=headers,
|
||||
method="POST",
|
||||
def _make_api_request(self, messages: Sequence[ChatMessageType], stream: bool = True):
|
||||
"""Create a provider request for arbitrary message lists."""
|
||||
max_tokens = self.args.max_tokens if stream and getattr(self.args, "max_tokens", None) else None
|
||||
options = self._build_request_options(stream=stream, max_tokens=max_tokens)
|
||||
return self.provider.create_request(messages, options)
|
||||
|
||||
def _resolve_model_name(self) -> str | None:
|
||||
if getattr(self.args, "runtime", None) == "mlx":
|
||||
return None
|
||||
return getattr(self.args, "model", None)
|
||||
|
||||
def _build_request_options(self, *, stream: bool, max_tokens: int | None) -> ChatRequestOptions:
|
||||
temperature = getattr(self.args, "temp", None)
|
||||
if max_tokens is not None and max_tokens <= 0:
|
||||
max_tokens = None
|
||||
return ChatRequestOptions(
|
||||
model=self._resolve_model_name(),
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
def initialize_mcp(self):
|
||||
@@ -281,7 +285,7 @@ class RamaLamaShell(cmd.Cmd):
|
||||
"""Determine if the request should be handled by MCP tools."""
|
||||
if not self.mcp_agent:
|
||||
return False
|
||||
return self.mcp_agent.should_use_tools(content, self.conversation_history)
|
||||
return self.mcp_agent.should_use_tools(content, self._history_snapshot())
|
||||
|
||||
def _handle_mcp_request(self, content: str) -> str:
|
||||
"""Handle a request using MCP tools (multi-tool capable, automatic)."""
|
||||
@@ -325,8 +329,8 @@ class RamaLamaShell(cmd.Cmd):
|
||||
print(f"\n {r['tool']} -> {r['output']}")
|
||||
|
||||
# Save to history
|
||||
self.conversation_history.append({"role": "user", "content": f"/tool {question}"})
|
||||
self.conversation_history.append({"role": "assistant", "content": str(responses)})
|
||||
self.conversation_history.append(UserMessage(text=f"/tool {question}"))
|
||||
self.conversation_history.append(AssistantMessage(text=str(responses)))
|
||||
|
||||
def _select_tools(self):
|
||||
"""Interactive multi-tool selection without prompting for arguments."""
|
||||
@@ -395,41 +399,26 @@ class RamaLamaShell(cmd.Cmd):
|
||||
# If streaming, _handle_mcp_request already printed output
|
||||
if isinstance(response, str) and response.strip():
|
||||
print(response)
|
||||
self.conversation_history.append({"role": "user", "content": content})
|
||||
self.conversation_history.append({"role": "assistant", "content": response})
|
||||
self.conversation_history.append(UserMessage(text=content))
|
||||
self.conversation_history.append(AssistantMessage(text=response))
|
||||
self._check_and_summarize()
|
||||
return False
|
||||
|
||||
self.conversation_history.append({"role": "user", "content": content})
|
||||
self.conversation_history.append(UserMessage(text=content))
|
||||
self.request_in_process = True
|
||||
response = self._req()
|
||||
if response:
|
||||
self.conversation_history.append({"role": "assistant", "content": response})
|
||||
self.conversation_history.append(AssistantMessage(text=response))
|
||||
self.request_in_process = False
|
||||
self._check_and_summarize()
|
||||
|
||||
def _make_request_data(self):
|
||||
data = {
|
||||
"stream": True,
|
||||
"messages": self.conversation_history,
|
||||
}
|
||||
if getattr(self.args, "temp", None):
|
||||
data["temperature"] = float(self.args.temp)
|
||||
if getattr(self.args, "max_tokens", None):
|
||||
data["max_completion_tokens"] = self.args.max_tokens
|
||||
# For MLX runtime, omit explicit model to allow server default ("default_model")
|
||||
if getattr(self.args, "runtime", None) != "mlx" and self.args.model is not None:
|
||||
data["model"] = self.args.model
|
||||
|
||||
json_data = json.dumps(data).encode("utf-8")
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
headers = add_api_key(self.args, headers)
|
||||
logger.debug("Request: URL=%s, Data=%s, Headers=%s", self.url, json_data, headers)
|
||||
request = urllib.request.Request(self.url, data=json_data, headers=headers, method="POST")
|
||||
|
||||
options = self._build_request_options(
|
||||
stream=True,
|
||||
max_tokens=getattr(self.args, "max_tokens", None),
|
||||
)
|
||||
request = self.provider.create_request(self.conversation_history, options)
|
||||
logger.debug("Request: URL=%s, Data=%s, Headers=%s", request.full_url, request.data, request.headers)
|
||||
return request
|
||||
|
||||
def _req(self):
|
||||
@@ -442,28 +431,46 @@ class RamaLamaShell(cmd.Cmd):
|
||||
# Adjust timeout based on whether we're in initial connection phase
|
||||
max_timeout = 30 if getattr(self.args, "initial_connection", False) else 16
|
||||
|
||||
for c in itertools.cycle(['⠋', '⠙', '⠹', '⠸', '⠼', '⠴', '⠦', '⠧', '⠇', '⠏']):
|
||||
last_error: Exception | None = None
|
||||
|
||||
spinner = Spinner().start()
|
||||
|
||||
while True:
|
||||
try:
|
||||
response = urllib.request.urlopen(request)
|
||||
spinner.stop()
|
||||
break
|
||||
except Exception:
|
||||
if sys.stdout.isatty():
|
||||
perror(f"\r{c}", end="", flush=True)
|
||||
except urllib.error.HTTPError as http_err:
|
||||
error_body = http_err.read().decode("utf-8", "ignore").strip()
|
||||
message = f"HTTP {http_err.code}"
|
||||
if error_body:
|
||||
message = f"{message}: {error_body}"
|
||||
perror(f"\r{message}")
|
||||
|
||||
if total_time_slept > max_timeout:
|
||||
break
|
||||
self.kills()
|
||||
spinner.stop()
|
||||
return None
|
||||
except Exception as exc:
|
||||
last_error = exc
|
||||
|
||||
total_time_slept += i
|
||||
time.sleep(i)
|
||||
if total_time_slept > max_timeout:
|
||||
break
|
||||
|
||||
i = min(i * 2, 0.1)
|
||||
total_time_slept += i
|
||||
time.sleep(i)
|
||||
|
||||
i = min(i * 2, 0.1)
|
||||
|
||||
spinner.stop()
|
||||
if response:
|
||||
return res(response, self.args.color)
|
||||
return stream_response(response, self.args.color, self.provider)
|
||||
|
||||
# Only show error and kill if not in initial connection phase
|
||||
if not getattr(self.args, "initial_connection", False):
|
||||
perror(f"\rError: could not connect to: {self.url}")
|
||||
error_suffix = ""
|
||||
if last_error:
|
||||
error_suffix = f" ({last_error})"
|
||||
perror(f"\rError: could not connect to: {self.url}{error_suffix}")
|
||||
self.kills()
|
||||
else:
|
||||
logger.debug(f"Could not connect to: {self.url}")
|
||||
@@ -721,12 +728,20 @@ def _report_server_exit(monitor):
|
||||
perror("Check server logs for more details about why the service exited.")
|
||||
|
||||
|
||||
def chat(args: ChatArgsType, operational_args: ChatOperationalArgs | None = None):
|
||||
def chat(
|
||||
args: ChatArgsType,
|
||||
operational_args: ChatOperationalArgs | None = None,
|
||||
provider: ChatProvider | None = None,
|
||||
):
|
||||
if args.dryrun:
|
||||
assert args.ARGS is not None
|
||||
prompt = " ".join(args.ARGS)
|
||||
print(f"\nramalama chat --color {args.color} --prefix \"{args.prefix}\" --url {args.url} {prompt}")
|
||||
return
|
||||
|
||||
if provider is None:
|
||||
provider = OpenAICompletionsChatProvider(args.url, getattr(args, "api_key", None))
|
||||
|
||||
# SIGALRM is Unix-only, skip keepalive timeout handling on Windows
|
||||
if getattr(args, "keepalive", False) and hasattr(signal, 'SIGALRM'):
|
||||
signal.signal(signal.SIGALRM, alarm_handler)
|
||||
@@ -752,14 +767,12 @@ def chat(args: ChatArgsType, operational_args: ChatOperationalArgs | None = None
|
||||
monitor.start()
|
||||
list_models = getattr(args, "list", False)
|
||||
if list_models:
|
||||
url = f"{args.url}/models"
|
||||
headers = add_api_key(args)
|
||||
req = urllib.request.Request(url, headers=headers)
|
||||
with urllib.request.urlopen(req) as response:
|
||||
data = json.loads(response.read())
|
||||
ids = [model["id"] for model in data.get("data", [])]
|
||||
for id in ids:
|
||||
print(id)
|
||||
for model_id in provider.list_models():
|
||||
print(model_id)
|
||||
monitor.stop()
|
||||
if hasattr(signal, 'alarm'):
|
||||
signal.alarm(0)
|
||||
return
|
||||
|
||||
# Ensure operational_args is initialized
|
||||
if operational_args is None:
|
||||
@@ -770,7 +783,7 @@ def chat(args: ChatArgsType, operational_args: ChatOperationalArgs | None = None
|
||||
|
||||
successful_exit = True
|
||||
try:
|
||||
shell = RamaLamaShell(args, operational_args)
|
||||
shell = RamaLamaShell(args, operational_args, provider=provider)
|
||||
if shell.handle_args(monitor):
|
||||
return
|
||||
|
||||
|
||||
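With this refactor `chat()` no longer hard-codes the local llama.cpp-style endpoint: it accepts an optional `ChatProvider` and only builds the OpenAI-completions default when none is passed. A small sketch of the two resolution paths (no network traffic, just URL construction):

```python
# Sketch: explicit hosted provider vs. the local default the shell falls back to.
from ramalama.chat_providers.api_providers import get_chat_provider
from ramalama.chat_providers.openai import OpenAICompletionsChatProvider

hosted = get_chat_provider("openai")                             # key resolved from config/env
local = OpenAICompletionsChatProvider("http://127.0.0.1:8080")   # same default RamaLamaShell builds

print(hosted.build_url())   # https://api.openai.com/v1/responses
print(local.build_url())    # http://127.0.0.1:8080/chat/completions
```

Passing the hosted object as `chat(args, provider=hosted)` routes the whole session through the remote endpoint.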
9
ramalama/chat_providers/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from ramalama.chat_providers import api_providers, openai
|
||||
from ramalama.chat_providers.base import (
|
||||
ChatProvider,
|
||||
ChatProviderError,
|
||||
ChatRequestOptions,
|
||||
ChatStreamEvent,
|
||||
)
|
||||
|
||||
__all__ = ["ChatProvider", "ChatProviderError", "ChatRequestOptions", "ChatStreamEvent", "openai", "api_providers"]
|
||||
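Call sites only need the package root for the shared types, for example:

```python
from ramalama.chat_providers import ChatProvider, ChatProviderError, ChatRequestOptions, ChatStreamEvent
```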
33
ramalama/chat_providers/api_providers.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from collections.abc import Callable
|
||||
|
||||
from ramalama.chat_providers.base import ChatProvider
|
||||
from ramalama.chat_providers.openai import OpenAIResponsesChatProvider
|
||||
from ramalama.config import CONFIG
|
||||
|
||||
PROVIDER_API_KEY_RESOLVERS: dict[str, Callable[[], str | None]] = {
|
||||
"openai": lambda: CONFIG.provider.openai.api_key,
|
||||
}
|
||||
|
||||
|
||||
def get_provider_api_key(scheme: str) -> str | None:
|
||||
"""Return a configured API key for the given provider scheme, if any."""
|
||||
|
||||
if resolver := PROVIDER_API_KEY_RESOLVERS.get(scheme):
|
||||
return resolver()
|
||||
return CONFIG.api_key
|
||||
|
||||
|
||||
DEFAULT_PROVIDERS = {
|
||||
"openai": lambda: OpenAIResponsesChatProvider(
|
||||
base_url="https://api.openai.com/v1", api_key=get_provider_api_key("openai")
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
def get_chat_provider(scheme: str) -> ChatProvider:
|
||||
if (resolver := DEFAULT_PROVIDERS.get(scheme, None)) is None:
|
||||
raise ValueError(f"No support chat providers for {scheme}")
|
||||
return resolver()
|
||||
|
||||
|
||||
__all__ = ["get_chat_provider", "get_provider_api_key"]
|
||||
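Because the scheme-to-provider mapping is a plain dict of factories, adding another OpenAI-compatible endpoint is a one-entry change. A hedged sketch (everything named "example" here is hypothetical):

```python
# Sketch: registering an additional hosted scheme; the "example" names are made up.
from ramalama.chat_providers import api_providers
from ramalama.chat_providers.openai import OpenAICompletionsChatProvider

api_providers.DEFAULT_PROVIDERS["example"] = lambda: OpenAICompletionsChatProvider(
    base_url="https://api.example.com/v1",                   # hypothetical OpenAI-compatible endpoint
    api_key=api_providers.get_provider_api_key("example"),   # no dedicated resolver, so CONFIG.api_key
)

provider = api_providers.get_chat_provider("example")
print(provider.build_url())   # https://api.example.com/v1/chat/completions
```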
201
ramalama/chat_providers/base.py
Normal file
@@ -0,0 +1,201 @@
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from urllib import error as urllib_error
|
||||
from urllib import request as urllib_request
|
||||
|
||||
from ramalama.chat_utils import ChatMessageType
|
||||
from ramalama.config import CONFIG
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ChatRequestOptions:
|
||||
"""Normalized knobs for building a chat completion request."""
|
||||
|
||||
model: str | None = None
|
||||
temperature: float | None = None
|
||||
max_tokens: int | None = None
|
||||
stream: bool = True
|
||||
extra: dict[str, Any] | None = None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
keys = ["model", "temperature", "max_tokens", "stream"]
|
||||
result = {k: v for k in keys if (v := getattr(self, k)) is not None}
|
||||
result |= {} if self.extra is None else dict(self.extra)
|
||||
return result
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ChatStreamEvent:
|
||||
"""A provider-agnostic representation of a streamed delta."""
|
||||
|
||||
text: str | None = None
|
||||
raw: dict[str, Any] | None = None
|
||||
done: bool = False
|
||||
|
||||
|
||||
class ChatProviderError(Exception):
|
||||
"""Raised when a provider request fails or returns an invalid payload."""
|
||||
|
||||
def __init__(self, message: str, *, status_code: int | None = None, payload: Any | None = None):
|
||||
super().__init__(message)
|
||||
self.status_code = status_code
|
||||
self.payload = payload
|
||||
|
||||
|
||||
class ChatProvider(ABC):
|
||||
"""Abstract base class for hosted chat providers."""
|
||||
|
||||
provider: str = "base"
|
||||
default_path: str
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str,
|
||||
api_key: str | None = None,
|
||||
default_headers: Mapping[str, str] | None = None,
|
||||
) -> None:
|
||||
if api_key is None:
|
||||
api_key = CONFIG.api_key
|
||||
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.api_key = api_key
|
||||
self._default_headers: dict[str, str] = dict(default_headers or {})
|
||||
|
||||
def build_url(self, path: str | None = None) -> str:
|
||||
rel = path or self.default_path
|
||||
if not rel.startswith("/"):
|
||||
rel = f"/{rel}"
|
||||
return f"{self.base_url}{rel}"
|
||||
|
||||
def prepare_headers(
|
||||
self,
|
||||
*,
|
||||
include_auth: bool = True,
|
||||
extra: dict[str, str] | None = None,
|
||||
options: ChatRequestOptions | None = None,
|
||||
) -> dict[str, str]:
|
||||
headers: dict[str, str] = {
|
||||
"Content-Type": "application/json",
|
||||
**self._default_headers,
|
||||
**self.provider_headers(options),
|
||||
}
|
||||
|
||||
if include_auth:
|
||||
headers.update(self.auth_headers())
|
||||
if extra:
|
||||
headers.update(extra)
|
||||
return headers
|
||||
|
||||
def auth_headers(self) -> dict[str, str]:
|
||||
return {"Authorization": f"Bearer {self.api_key}"} if self.api_key else {}
|
||||
|
||||
def serialize_payload(self, payload: Mapping[str, Any]) -> bytes:
|
||||
return json.dumps(payload).encode("utf-8")
|
||||
|
||||
def create_request(
|
||||
self, messages: Sequence[ChatMessageType], options: ChatRequestOptions
|
||||
) -> urllib_request.Request:
|
||||
payload = self.build_payload(messages, options)
|
||||
headers = self.prepare_headers(options=options, extra=self.additional_request_headers(options))
|
||||
body = self.serialize_payload(payload)
|
||||
return urllib_request.Request(
|
||||
self.build_url(self.resolve_request_path(options)),
|
||||
data=body,
|
||||
headers=headers,
|
||||
method="POST",
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Provider customization points
|
||||
# ------------------------------------------------------------------
|
||||
def provider_headers(self, options: ChatRequestOptions | None = None) -> dict[str, str]:
|
||||
return {}
|
||||
|
||||
def additional_request_headers(self, options: ChatRequestOptions | None = None) -> dict[str, str]:
|
||||
return {}
|
||||
|
||||
def resolve_request_path(self, options: ChatRequestOptions | None = None) -> str:
|
||||
return self.default_path
|
||||
|
||||
@abstractmethod
|
||||
def build_payload(self, messages: Sequence[ChatMessageType], options: ChatRequestOptions) -> Mapping[str, Any]:
|
||||
"""Return the provider-specific payload."""
|
||||
|
||||
@abstractmethod
|
||||
def parse_stream_chunk(self, chunk: bytes) -> Iterable[ChatStreamEvent]:
|
||||
"""Yield zero or more events parsed from a streamed response chunk."""
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Error handling
|
||||
# ------------------------------------------------------------------
|
||||
def raise_for_status(self, status_code: int, payload: Any | None = None) -> None:
|
||||
if status_code >= 400:
|
||||
if isinstance(payload, dict) and "error" in payload:
|
||||
err = payload["error"]
|
||||
message = str(err.get("message") or err.get("type") or err) if isinstance(err, dict) else str(err)
|
||||
else:
|
||||
message = "chat request failed"
|
||||
|
||||
raise ChatProviderError(message, status_code=status_code, payload=payload)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Non-streamed helpers
|
||||
# ------------------------------------------------------------------
|
||||
def parse_response_body(self, body: bytes) -> Any:
|
||||
if not body:
|
||||
return None
|
||||
return json.loads(body.decode("utf-8"))
|
||||
|
||||
def list_models(self) -> list[str]:
|
||||
"""Return available model identifiers exposed by the provider."""
|
||||
|
||||
request = urllib_request.Request(
|
||||
self.build_url("/models"),
|
||||
headers=self.prepare_headers(include_auth=True),
|
||||
method="GET",
|
||||
)
|
||||
try:
|
||||
with urllib_request.urlopen(request) as response:
|
||||
payload = self.parse_response_body(response.read())
|
||||
except urllib_error.HTTPError as exc:
|
||||
if exc.code in (401, 403):
|
||||
message = (
|
||||
f"Could not authenticate with {self.provider}."
|
||||
"The provided API key was either missing or invalid.\n"
|
||||
f"Set RAMALAMA_API_KEY or ramalama.provider.<provider_name>.api_key."
|
||||
)
|
||||
try:
|
||||
payload = self.parse_response_body(exc.read()) or {}
|
||||
except Exception:
|
||||
payload = {}
|
||||
|
||||
if details := payload.get("error", {}).get("message", None):
|
||||
message = f"{message}\n\n{details}"
|
||||
|
||||
raise ChatProviderError(message, status_code=exc.code) from exc
|
||||
raise
|
||||
|
||||
if not isinstance(payload, Mapping):
|
||||
raise ChatProviderError("Invalid model list payload", payload=payload)
|
||||
|
||||
data = payload.get("data")
|
||||
if not isinstance(data, list):
|
||||
raise ChatProviderError("Invalid model list payload", payload=payload)
|
||||
|
||||
models: list[str] = []
|
||||
for entry in data:
|
||||
if isinstance(entry, Mapping) and (model_id := entry.get("id")):
|
||||
models.append(str(model_id))
|
||||
|
||||
return models
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ChatProvider",
|
||||
"ChatProviderError",
|
||||
"ChatRequestOptions",
|
||||
"ChatStreamEvent",
|
||||
]
|
||||
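Only `build_payload` and `parse_stream_chunk` are abstract, so a new provider needs very little code. A minimal subclass sketch against an invented wire format (single `data:` JSON lines with a top-level `text` field); it is illustration only, not a real service:

```python
# Sketch: smallest useful ChatProvider subclass. The wire format is invented for illustration.
import json
from collections.abc import Iterable, Mapping, Sequence
from typing import Any

from ramalama.chat_providers.base import ChatProvider, ChatRequestOptions, ChatStreamEvent
from ramalama.chat_utils import ChatMessageType


class PlainTextChatProvider(ChatProvider):
    provider = "plaintext"          # used in error messages
    default_path = "/v1/complete"   # hypothetical endpoint path

    def build_payload(self, messages: Sequence[ChatMessageType], options: ChatRequestOptions) -> Mapping[str, Any]:
        # Flatten the typed messages into the imaginary service's single-prompt schema.
        return {
            "prompt": "\n".join(f"{m.role}: {m.text or ''}" for m in messages),
            **options.to_dict(),
        }

    def parse_stream_chunk(self, chunk: bytes) -> Iterable[ChatStreamEvent]:
        # Assumes each chunk holds whole `data: {...}` lines; real servers need buffering as above.
        for line in chunk.decode("utf-8").splitlines():
            if line.startswith("data:"):
                body = json.loads(line[len("data:"):].strip())
                yield ChatStreamEvent(text=body.get("text"), raw=body)
```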
338
ramalama/chat_providers/openai.py
Normal file
@@ -0,0 +1,338 @@
|
||||
import json
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from functools import singledispatch
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from ramalama.chat_providers.base import ChatProvider, ChatRequestOptions, ChatStreamEvent
|
||||
from ramalama.chat_utils import (
|
||||
AssistantMessage,
|
||||
AttachmentPart,
|
||||
ChatMessageType,
|
||||
SystemMessage,
|
||||
ToolMessage,
|
||||
UserMessage,
|
||||
serialize_part,
|
||||
)
|
||||
|
||||
|
||||
class UnsupportedMessageType(Exception):
|
||||
"""Raised when a provider request fails or returns an invalid payload."""
|
||||
|
||||
|
||||
@singledispatch
|
||||
def message_to_completions_dict(message: Any) -> dict[str, Any]:
|
||||
message = (
|
||||
f"Cannot convert message type `{type(message)}` to a completions dictionary.\n"
|
||||
"Please create an issue at: https://github.com/containers/ramalama/issues"
|
||||
)
|
||||
raise UnsupportedMessageType(message)
|
||||
|
||||
|
||||
@message_to_completions_dict.register
|
||||
def _(message: SystemMessage) -> dict[str, Any]:
|
||||
return {**message.metadata, 'content': message.text or "", 'role': message.role}
|
||||
|
||||
|
||||
@message_to_completions_dict.register
|
||||
def _(message: ToolMessage) -> dict[str, Any]:
|
||||
response = {
|
||||
**message.metadata,
|
||||
'content': message.text or "",
|
||||
'role': message.role,
|
||||
}
|
||||
if message.tool_call_id:
|
||||
response['tool_call_id'] = message.tool_call_id
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@message_to_completions_dict.register
|
||||
def _(message: UserMessage) -> dict[str, Any]:
|
||||
if message.attachments:
|
||||
raise ValueError("Attachments are not supported by this provider.")
|
||||
return {**message.metadata, 'content': message.text or "", 'role': message.role}
|
||||
|
||||
|
||||
@message_to_completions_dict.register
|
||||
def _(message: AssistantMessage) -> dict[str, Any]:
|
||||
if message.attachments:
|
||||
raise ValueError("Attachments are not supported by this provider.")
|
||||
|
||||
tool_calls = [
|
||||
{
|
||||
"id": call.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": call.name,
|
||||
"arguments": json.dumps(call.arguments, ensure_ascii=False),
|
||||
},
|
||||
}
|
||||
for call in message.tool_calls
|
||||
]
|
||||
return {**message.metadata, 'content': message.text or "", 'role': message.role, 'tool_calls': tool_calls}
|
||||
|
||||
|
||||
class CompletionsPayload(TypedDict, total=False):
|
||||
messages: list[dict[str, Any]]
|
||||
model: str | None
|
||||
temperature: float | None
|
||||
max_tokens: int | None
|
||||
stream: bool
|
||||
|
||||
|
||||
class OpenAICompletionsChatProvider(ChatProvider):
|
||||
provider = "openai"
|
||||
default_path = "/chat/completions"
|
||||
|
||||
def __init__(self, base_url: str, api_key: str | None = None):
|
||||
super().__init__(base_url, api_key)
|
||||
self._stream_buffer: str = ""
|
||||
|
||||
def build_payload(self, messages: Sequence[ChatMessageType], options: ChatRequestOptions) -> CompletionsPayload:
|
||||
payload: CompletionsPayload = {
|
||||
"messages": [message_to_completions_dict(m) for m in messages],
|
||||
"model": options.model,
|
||||
"temperature": options.temperature,
|
||||
"max_tokens": options.max_tokens,
|
||||
"stream": options.stream,
|
||||
}
|
||||
return payload
|
||||
|
||||
def parse_stream_chunk(self, chunk: bytes) -> Iterable[ChatStreamEvent]:
|
||||
events: list[ChatStreamEvent] = []
|
||||
self._stream_buffer += chunk.decode("utf-8")
|
||||
|
||||
while "\n\n" in self._stream_buffer:
|
||||
raw_event, self._stream_buffer = self._stream_buffer.split("\n\n", 1)
|
||||
raw_event = raw_event.strip()
|
||||
if not raw_event:
|
||||
continue
|
||||
for line in raw_event.splitlines():
|
||||
if not line.startswith("data:"):
|
||||
continue
|
||||
payload = line[len("data:") :].strip()
|
||||
if not payload:
|
||||
continue
|
||||
if payload == "[DONE]":
|
||||
events.append(ChatStreamEvent(done=True))
|
||||
continue
|
||||
try:
|
||||
parsed = json.loads(payload)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
if delta := self._extract_delta(parsed):
|
||||
events.append(ChatStreamEvent(text=delta, raw=parsed))
|
||||
|
||||
return events
|
||||
|
||||
def _extract_delta(self, payload: Mapping[str, object]) -> str | None:
|
||||
choices = payload.get("choices")
|
||||
if not isinstance(choices, list) or not choices:
|
||||
return None
|
||||
|
||||
choice = choices[0]
|
||||
if not isinstance(choice, Mapping):
|
||||
return None
|
||||
|
||||
delta = choice.get("delta")
|
||||
if not isinstance(delta, Mapping):
|
||||
return None
|
||||
|
||||
content = delta.get("content")
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
|
||||
if isinstance(content, list):
|
||||
parts: list[str] = []
|
||||
for entry in content:
|
||||
if not isinstance(entry, Mapping):
|
||||
continue
|
||||
entry_type = entry.get("type")
|
||||
text_value = entry.get("text")
|
||||
if entry_type in {"text", "output_text"} and isinstance(text_value, str):
|
||||
parts.append(text_value)
|
||||
if parts:
|
||||
return "".join(parts)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@singledispatch
|
||||
def message_to_responses_dict(message: Any) -> dict[str, Any]:
|
||||
raise ValueError(f"Undefined message type {type(message)}")
|
||||
|
||||
|
||||
def create_responses_content(
|
||||
text: str | None, attachments: list[AttachmentPart], content_type: str
|
||||
) -> list[dict[str, Any]] | str:
|
||||
"""
|
||||
TODO: Current structure doesn't correctly reflect document ordering
|
||||
(i.e. the possibility of messages interspersed with content)
|
||||
"""
|
||||
content: list[dict[str, Any]] = []
|
||||
if text:
|
||||
content.append({"type": content_type, "text": text})
|
||||
for attachment in attachments:
|
||||
content.append(serialize_part(attachment))
|
||||
|
||||
return content or ""
|
||||
|
||||
|
||||
@message_to_responses_dict.register
|
||||
def _(message: SystemMessage) -> dict[str, Any]:
|
||||
return {**message.metadata, 'content': message.text or "", 'role': message.role}
|
||||
|
||||
|
||||
@message_to_responses_dict.register
|
||||
def _(message: ToolMessage) -> dict[str, Any]:
|
||||
response = {
|
||||
**message.metadata,
|
||||
'content': message.text or "",
|
||||
'role': message.role,
|
||||
}
|
||||
if message.tool_call_id:
|
||||
response['tool_call_id'] = message.tool_call_id
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@message_to_responses_dict.register
|
||||
def _(message: UserMessage) -> dict[str, Any]:
|
||||
return {
|
||||
**message.metadata,
|
||||
'content': create_responses_content(message.text, message.attachments, "input_text"),
|
||||
'role': message.role,
|
||||
}
|
||||
|
||||
|
||||
@message_to_responses_dict.register
|
||||
def _(message: AssistantMessage) -> dict[str, Any]:
|
||||
payload: dict[str, Any] = {
|
||||
**message.metadata,
|
||||
'content': create_responses_content(message.text, message.attachments, "output_text"),
|
||||
'role': message.role,
|
||||
}
|
||||
|
||||
tool_calls = [
|
||||
{
|
||||
"id": call.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": call.name,
|
||||
"arguments": json.dumps(call.arguments, ensure_ascii=False),
|
||||
},
|
||||
}
|
||||
for call in message.tool_calls
|
||||
]
|
||||
if tool_calls:
|
||||
payload['tool_calls'] = tool_calls
|
||||
return payload
|
||||
|
||||
|
||||
class ResponsesPayload(TypedDict, total=False):
|
||||
input: list[dict[str, Any]]
|
||||
model: str
|
||||
temperature: float | None
|
||||
max_completion_tokens: int
|
||||
stream: bool
|
||||
|
||||
|
||||
class OpenAIResponsesChatProvider(ChatProvider):
|
||||
provider = "openai"
|
||||
default_path: str = "/responses"
|
||||
|
||||
def __init__(self, base_url: str, api_key: str | None = None):
|
||||
super().__init__(base_url, api_key)
|
||||
self._stream_buffer: str = ""
|
||||
|
||||
def build_payload(self, messages: Sequence[ChatMessageType], options: ChatRequestOptions) -> ResponsesPayload:
|
||||
if options.model is None:
|
||||
raise ValueError("Chat options require a model value")
|
||||
|
||||
payload: ResponsesPayload = {
|
||||
"input": [message_to_responses_dict(m) for m in messages],
|
||||
"model": options.model,
|
||||
"temperature": options.temperature,
|
||||
"stream": options.stream,
|
||||
}
|
||||
|
||||
if options.max_tokens is not None and options.max_tokens > 0:
|
||||
payload["max_completion_tokens"] = options.max_tokens
|
||||
return payload
|
||||
|
||||
def parse_stream_chunk(self, chunk: bytes) -> Iterable[ChatStreamEvent]:
|
||||
events: list[ChatStreamEvent] = []
|
||||
self._stream_buffer += chunk.decode("utf-8")
|
||||
|
||||
while "\n\n" in self._stream_buffer:
|
||||
raw_event, self._stream_buffer = self._stream_buffer.split("\n\n", 1)
|
||||
raw_event = raw_event.strip()
|
||||
if not raw_event:
|
||||
continue
|
||||
|
||||
event_type = ""
|
||||
data_lines: list[str] = []
|
||||
for line in raw_event.splitlines():
|
||||
if line.startswith("event:"):
|
||||
event_type = line[len("event:") :].strip()
|
||||
elif line.startswith("data:"):
|
||||
data_lines.append(line[len("data:") :].strip())
|
||||
|
||||
data = "\n".join(data_lines).strip()
|
||||
if not data:
|
||||
continue
|
||||
|
||||
if data == "[DONE]":
|
||||
events.append(ChatStreamEvent(done=True))
|
||||
continue
|
||||
|
||||
try:
|
||||
payload = json.loads(data)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
if self._is_completion_event(event_type, payload):
|
||||
events.append(ChatStreamEvent(done=True, raw=payload))
|
||||
continue
|
||||
|
||||
if text := self._extract_responses_delta(event_type, payload):
|
||||
events.append(ChatStreamEvent(text=text, raw=payload))
|
||||
|
||||
return events
|
||||
|
||||
@staticmethod
|
||||
def _is_completion_event(event_type: str, payload: Mapping[str, Any]) -> bool:
|
||||
hinted_type = event_type or (payload.get("type") if isinstance(payload, Mapping) else "")
|
||||
return hinted_type == "response.completed"
|
||||
|
||||
@staticmethod
|
||||
def _extract_responses_delta(event_type: str, payload: Mapping[str, Any]) -> str | None:
|
||||
if not event_type:
|
||||
event_type = payload.get("type", "") if isinstance(payload, Mapping) else ""
|
||||
|
||||
if event_type == "response.output_text.delta":
|
||||
delta = payload.get("delta")
|
||||
if isinstance(delta, Mapping):
|
||||
text = delta.get("text")
|
||||
if isinstance(text, str):
|
||||
return text
|
||||
elif isinstance(delta, str):
|
||||
return delta
|
||||
|
||||
if event_type == "response.output_text.done":
|
||||
output = payload.get("output")
|
||||
if isinstance(output, list) and output:
|
||||
first = output[0]
|
||||
if isinstance(first, Mapping):
|
||||
content = first.get("content")
|
||||
if isinstance(content, list) and content:
|
||||
entry = content[0]
|
||||
if isinstance(entry, Mapping):
|
||||
text = entry.get("text")
|
||||
if isinstance(text, str):
|
||||
return text
|
||||
return None
|
||||
|
||||
|
||||
__all__ = ["OpenAICompletionsChatProvider", "OpenAIResponsesChatProvider"]
|
||||
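To see the whole path offline, the sketch below builds a completions request and feeds a hand-written SSE chunk through `parse_stream_chunk`; no network call is made, and the sample bytes just mimic the `chat/completions` delta format:

```python
# Sketch: exercising OpenAICompletionsChatProvider with fabricated streaming bytes.
from ramalama.chat_providers.base import ChatRequestOptions
from ramalama.chat_providers.openai import OpenAICompletionsChatProvider
from ramalama.chat_utils import UserMessage

provider = OpenAICompletionsChatProvider("http://127.0.0.1:8080", api_key=None)
request = provider.create_request(
    [UserMessage(text="Hello")],
    ChatRequestOptions(model="smollm", temperature=0.2, stream=True),
)
print(request.full_url)   # http://127.0.0.1:8080/chat/completions

chunk = (
    b'data: {"choices": [{"delta": {"content": "Hi"}}]}\n\n'
    b'data: [DONE]\n\n'
)
for event in provider.parse_stream_chunk(chunk):
    print(event.text, event.done)   # "Hi False" then "None True"
```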
137
ramalama/chat_utils.py
Normal file
@@ -0,0 +1,137 @@
|
||||
import base64
|
||||
import os
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Literal, Protocol
|
||||
|
||||
from ramalama.config import CONFIG
|
||||
from ramalama.console import EMOJI, should_colorize
|
||||
|
||||
RoleType = Literal["system", "user", "assistant", "tool"]
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ImageURLPart:
|
||||
url: str
|
||||
detail: str | None = None
|
||||
type: Literal["image_url"] = "image_url"
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ImageBytesPart:
|
||||
data: bytes
|
||||
mime_type: str = "application/octet-stream"
|
||||
type: Literal["image_bytes"] = "image_bytes"
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ToolCall:
|
||||
id: str
|
||||
name: str
|
||||
arguments: dict[str, Any]
|
||||
|
||||
|
||||
AttachmentPart = ImageURLPart | ImageBytesPart
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class SystemMessage:
|
||||
role: Literal["system"] = "system"
|
||||
text: str = ""
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class UserMessage:
|
||||
role: Literal["user"] = "user"
|
||||
text: str | None = None
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
attachments: list[AttachmentPart] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class AssistantMessage:
|
||||
role: Literal["assistant"] = "assistant"
|
||||
text: str | None = None
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
tool_calls: list[ToolCall] = field(default_factory=list)
|
||||
attachments: list[AttachmentPart] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ToolMessage:
|
||||
text: str
|
||||
role: Literal["tool"] = "tool"
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
tool_call_id: str | None = None
|
||||
|
||||
|
||||
ChatMessageType = SystemMessage | UserMessage | AssistantMessage | ToolMessage
|
||||
|
||||
|
||||
class StreamParser(Protocol):
|
||||
def parse_stream_chunk(self, chunk: bytes) -> Iterable[Any]: ...
|
||||
|
||||
|
||||
def stream_response(chunks: Iterable[bytes], color: str, provider: StreamParser) -> str:
|
||||
color_default = ""
|
||||
color_yellow = ""
|
||||
if (color == "auto" and should_colorize()) or color == "always":
|
||||
color_default = "\033[0m"
|
||||
color_yellow = "\033[33m"
|
||||
|
||||
print("\r", end="")
|
||||
assistant_response = ""
|
||||
for chunk in chunks:
|
||||
events = provider.parse_stream_chunk(chunk)
|
||||
for event in events:
|
||||
text = getattr(event, "text", None)
|
||||
if not text:
|
||||
continue
|
||||
print(f"{color_yellow}{text}{color_default}", end="", flush=True)
|
||||
assistant_response += text
|
||||
|
||||
print("")
|
||||
return assistant_response
|
||||
|
||||
|
||||
def default_prefix() -> str:
|
||||
if not EMOJI:
|
||||
return "> "
|
||||
|
||||
if CONFIG.prefix:
|
||||
return CONFIG.prefix
|
||||
|
||||
if engine := CONFIG.engine:
|
||||
if os.path.basename(engine) == "podman":
|
||||
return "🦭 > "
|
||||
|
||||
if os.path.basename(engine) == "docker":
|
||||
return "🐋 > "
|
||||
|
||||
return "🦙 > "
|
||||
|
||||
|
||||
def serialize_part(part: AttachmentPart) -> dict[str, Any]:
|
||||
if isinstance(part, ImageURLPart):
|
||||
payload: dict[str, Any] = {"url": part.url}
|
||||
if part.detail:
|
||||
payload["detail"] = part.detail
|
||||
return {"type": "image_url", "image_url": payload}
|
||||
if isinstance(part, ImageBytesPart):
|
||||
return {
|
||||
"type": "image_bytes",
|
||||
"image_bytes": {"data": base64.b64encode(part.data).decode("ascii"), "mime_type": part.mime_type},
|
||||
}
|
||||
|
||||
raise TypeError(f"Unsupported message part: {part!r}")
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ToolCall",
|
||||
"ImageURLPart",
|
||||
"ImageBytesPart",
|
||||
"default_prefix",
|
||||
"stream_response",
|
||||
"serialize_part",
|
||||
]
|
||||
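For orientation, a small sketch of how the dataclasses above compose (the names are the ones introduced in this file; the URL is made up):

    from ramalama.chat_utils import ImageURLPart, UserMessage, serialize_part

    # A user turn carrying both text and an image attachment.
    message = UserMessage(
        text="what is in this picture?",
        attachments=[ImageURLPart(url="https://example.com/cat.png", detail="high")],
    )

    # Providers turn each attachment into their wire format via serialize_part().
    print(serialize_part(message.attachments[0]))
    # {'type': 'image_url', 'image_url': {'url': 'https://example.com/cat.png', 'detail': 'high'}}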
@@ -24,7 +24,7 @@ except Exception:
import ramalama.chat as chat
from ramalama import engine
from ramalama.arg_types import DefaultArgsType
from ramalama.chat import default_prefix
from ramalama.chat_utils import default_prefix
from ramalama.cli_arg_normalization import normalize_pull_arg
from ramalama.command.factory import assemble_command
from ramalama.common import accel_image, get_accel, perror
@@ -48,6 +48,7 @@ from ramalama.path_utils import file_uri_to_path
from ramalama.rag import INPUT_DIR, Rag, RagTransport, rag_image
from ramalama.shortnames import Shortnames
from ramalama.stack import Stack
from ramalama.transports.api import APITransport
from ramalama.transports.base import (
MODEL_TYPES,
NoGGUFModelFileFound,
@@ -1033,7 +1034,8 @@ If GPU device on host is accessible to via group access, this option leaks the u
)
parser.add_argument(
"--temp",
default=CONFIG.temp,
type=float,
default=float(CONFIG.temp),
help="temperature of the response from the AI model",
completer=suppressCompleter,
)
@@ -1118,6 +1120,19 @@ def chat_parser(subparsers):
parser.add_argument("--url", type=str, default="http://127.0.0.1:8080/v1", help="the url to send requests to")
parser.add_argument("--model", "-m", type=str, completer=local_models, help="model for inferencing")
parser.add_argument("--rag", type=str, help="a file or directory to use as context for the chat")
parser.add_argument(
"--max-tokens",
dest="max_tokens",
type=int,
default=CONFIG.max_tokens,
help="maximum number of tokens to generate (0 = unlimited)",
)
parser.add_argument(
"--temp",
type=float,
default=float(CONFIG.temp),
help="temperature of the response from the AI model",
)
parser.add_argument(
"ARGS", nargs="*", help="overrides the default prompt, and the output is returned without entering the chatbot"
)
@@ -1165,7 +1180,6 @@ def run_cli(args):
try:
# detect available port and update arguments
args.port = compute_serving_port(args)

model = New(args.MODEL, args)
model.ensure_model_exists(args)
except KeyError as e:
@@ -1177,6 +1191,11 @@ def run_cli(args):
except Exception as exc:
raise e from exc

is_api_transport = isinstance(model, APITransport)

if args.rag and is_api_transport:
raise ValueError("ramalama run --rag is not supported for hosted API transports.")

if args.rag:
if not args.container:
raise ValueError("ramalama run --rag cannot be run with the --nocontainer option.")
@@ -1184,7 +1203,9 @@ def run_cli(args):
model = RagTransport(model, assemble_command(args.model_args), args)
model.ensure_model_exists(args)

model.run(args, assemble_command(args))
server_cmd = [] if isinstance(model, APITransport) else assemble_command(args)

model.run(args, server_cmd)


def serve_parser(subparsers):
@@ -1224,6 +1245,9 @@ def serve_cli(args):
except Exception:
raise e

if isinstance(model, APITransport):
raise ValueError("ramalama serve is not supported for hosted API transports.")

if args.rag:
if not args.container:
raise ValueError("ramalama serve --rag cannot be run with the --nocontainer option.")
@@ -1,5 +1,3 @@
"""ramalama common module."""

from __future__ import annotations

import glob

@@ -144,6 +144,16 @@ class UserConfig:
self.no_missing_gpu_prompt = coerce_to_bool(self.no_missing_gpu_prompt)


@dataclass
class OpenaiProviderConfig:
api_key: str | None = None


@dataclass
class ProviderConfig:
openai: OpenaiProviderConfig = field(default_factory=OpenaiProviderConfig)


@dataclass
class RamalamaSettings:
"""These settings are not managed directly by the user"""
@@ -253,6 +263,7 @@ class BaseConfig:
gguf_quantization_mode: GGUF_QUANTIZATION_MODES = DEFAULT_GGUF_QUANTIZATION_MODE
http_client: HTTPClientConfig = field(default_factory=HTTPClientConfig)
log_level: LogLevel | None = None
provider: ProviderConfig = field(default_factory=ProviderConfig)

def __post_init__(self):
self.container = coerce_to_bool(self.container) if self.container is not None else self.engine is not None
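The config additions above surface as a nested provider block on the parsed configuration; a minimal sketch of reading it (assuming CONFIG is the usual ramalama.config singleton, as used elsewhere in this change):

    from ramalama.config import CONFIG

    # None when no key is configured; APITransport.ensure_model_exists() raises a
    # ValueError in that case (see ramalama/transports/api.py below).
    api_key = CONFIG.provider.openai.api_key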
@@ -4,6 +4,7 @@ from string import Template
from typing import Type
from warnings import warn

from ramalama.chat_utils import AttachmentPart, ChatMessageType, ImageURLPart, UserMessage
from ramalama.file_loaders.file_types import base, image, txt


@@ -115,17 +116,18 @@ class OpanAIChatAPIMessageBuilder:
def supported_extensions(self) -> set[str]:
return self.text_manager.loaders.keys() | self.image_manager.loaders.keys()

def load(self, file_path: str) -> list[dict]:
def load(self, file_path: str) -> list[ChatMessageType]:
text_files, image_files, unsupported_files = self.partition_files(file_path)

if unsupported_files:
unsupported_files_warning(unsupported_files, list(self.supported_extensions()))

messages: list[dict] = []
messages: list[ChatMessageType] = []
if text_files:
messages.append({"role": "system", "content": self.text_manager.load(text_files)})
messages.append(UserMessage(text=self.text_manager.load(text_files)))
if image_files:
content = [{"type": "image_url", "image_url": {"url": c}} for c in self.image_manager.load(image_files)]
message = {"role": "system", "content": content}
messages.append(message)
attachments: list[AttachmentPart] = []
for data_url in self.image_manager.load(image_files):
attachments.append(ImageURLPart(url=data_url))
messages.append(UserMessage(attachments=attachments))
return messages
@@ -124,7 +124,7 @@ class Stack:
'--ctx-size',
str(self.args.context),
'--temp',
self.args.temp,
str(self.args.temp),
'--jinja',
'--cache-reuse',
'256',
@@ -1,11 +1,21 @@
from .huggingface import Huggingface, HuggingfaceRepository
from .modelscope import ModelScope, ModelScopeRepository
from .oci import OCI
from .ollama import Ollama, OllamaRepository
from .rlcr import RamalamaContainerRegistry
from .url import URL
from ramalama.transports import api, huggingface, modelscope, oci, ollama, rlcr, transport_factory, url
from ramalama.transports.api import APITransport
from ramalama.transports.huggingface import Huggingface, HuggingfaceRepository
from ramalama.transports.modelscope import ModelScope, ModelScopeRepository
from ramalama.transports.oci import OCI
from ramalama.transports.ollama import Ollama, OllamaRepository
from ramalama.transports.rlcr import RamalamaContainerRegistry
from ramalama.transports.url import URL

__all__ = [
"api",
"huggingface",
"oci",
"modelscope",
"ollama",
"rlcr",
"transport_factory",
"url",
"Huggingface",
"HuggingfaceRepository",
"ModelScope",
@@ -15,4 +25,5 @@ __all__ = [
"OllamaRepository",
"RamalamaContainerRegistry",
"URL",
"APITransport",
]
108 ramalama/transports/api.py Normal file
@@ -0,0 +1,108 @@
from typing import Any

from ramalama.chat import chat
from ramalama.chat_providers.base import ChatProvider, ChatProviderError
from ramalama.common import perror
from ramalama.transports.base import TransportBase


class APITransport(TransportBase):
    """Transport that proxies chat requests to a hosted API provider."""

    type: str = "api"

    def __init__(self, model: str, provider: ChatProvider):
        self.model = model
        self.provider = provider

        self._model_tag = "latest"
        self._model_name = self.model
        self.draft_model = None

    @property
    def model_name(self) -> str:
        return self.model

    @property
    def model_tag(self) -> str:
        return self._model_tag

    @property
    def model_organization(self) -> str:
        return self.provider.provider

    @property
    def model_type(self) -> str:
        return self.type

    def _get_entry_model_path(self, use_container: bool, should_generate: bool, dry_run: bool) -> str:
        raise NotImplementedError(
            f"{self.model} is provided over a hosted API preventing direct pulling of the model file."
        )

    def _get_mmproj_path(self, use_container: bool, should_generate: bool, dry_run: bool):
        return None

    def _get_chat_template_path(self, use_container: bool, should_generate: bool, dry_run: bool):
        return None

    def remove(self, args):
        raise NotImplementedError("Hosted API transports do not support removing remote models.")

    def bench(self, args, cmd: list[str]):
        raise NotImplementedError("bench is not supported for hosted API transports.")

    def run(self, args, server_cmd: list[str]):
        """Connect directly to the provider instead of launching a local server."""
        args.container = False
        args.engine = None
        args.model = self.model

        if getattr(args, "url", None):
            self.provider.base_url = args.url

        if getattr(args, "api_key", None):
            self.provider.api_key = args.api_key

        chat(args, provider=self.provider)

    def perplexity(self, args, cmd: list[str]):
        raise NotImplementedError("perplexity is not supported for hosted API transports.")

    def serve(self, args, cmd: list[str]):
        raise NotImplementedError("Hosted API transports cannot be served locally.")

    def exists(self) -> bool:
        return True

    def inspect(self, args):
        return {
            "provider": self.provider.provider,
            "model": self.model_name,
            "base_url": self.provider.base_url,
        }

    def ensure_model_exists(self, args):
        args.container = False
        args.engine = None
        if not self.provider.api_key:
            raise ValueError(
                f'Missing API key for provider "{self.provider.provider}". '
                "Set RAMALAMA_API_KEY or ramalama.provider.openai.api_key."
            )
        try:
            models = self.provider.list_models()
        except ChatProviderError as exc:
            raise ValueError(str(exc)) from exc
        except Exception as exc:
            raise RuntimeError(f'Failed to list models for provider "{self.provider.provider}"') from exc

        if self.model not in models:
            available = ", ".join(models) if models else "none"
            raise ValueError(
                f'Model "{self.model}" not available from provider "{self.provider.provider}". '
                f"Available models: {available}"
            )

    def pull(self, args: Any):
        perror(f"{self.model} is provided over a hosted API preventing direct pulling of the model file.")
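A hosted transport wraps a provider instead of a local model store; a hedged sketch of the flow implemented above (the args object is illustrative, and a valid API key must be available via RAMALAMA_API_KEY or the provider config for ensure_model_exists to pass):

    from types import SimpleNamespace

    from ramalama.chat_providers.api_providers import get_chat_provider
    from ramalama.transports.api import APITransport

    provider = get_chat_provider("openai")      # OpenAIResponsesChatProvider, base_url https://api.openai.com/v1
    transport = APITransport("gpt-4o-mini", provider)

    args = SimpleNamespace(container=True, engine="podman", url=None, model=None, api_key=None)
    transport.ensure_model_exists(args)         # checks the API key and that the model is listed by the provider
    transport.run(args, [])                     # disables the container path and chats directly against the provider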
@@ -4,10 +4,12 @@ from typing import TypeAlias
from urllib.parse import urlparse

from ramalama.arg_types import StoreArgType
from ramalama.chat_providers.api_providers import get_chat_provider
from ramalama.common import rm_until_substring
from ramalama.config import CONFIG
from ramalama.path_utils import file_uri_to_path
from ramalama.transports.base import MODEL_TYPES
from ramalama.transports.api import APITransport
from ramalama.transports.base import MODEL_TYPES, Transport
from ramalama.transports.huggingface import Huggingface
from ramalama.transports.modelscope import ModelScope
from ramalama.transports.oci import OCI
@@ -15,7 +17,7 @@ from ramalama.transports.ollama import Ollama
from ramalama.transports.rlcr import RamalamaContainerRegistry
from ramalama.transports.url import URL

CLASS_MODEL_TYPES: TypeAlias = Huggingface | Ollama | OCI | URL | ModelScope | RamalamaContainerRegistry
CLASS_MODEL_TYPES: TypeAlias = Huggingface | Ollama | OCI | URL | ModelScope | RamalamaContainerRegistry | APITransport


class TransportFactory:
@@ -26,6 +28,7 @@ class TransportFactory:
transport: str = "ollama",
ignore_stderr: bool = False,
):

self.model = model
self.store_path = args.store
self.transport = transport
@@ -39,41 +42,45 @@ class TransportFactory:
self._create = _create

self.pruned_model = self.prune_model_input()
self.draft_model = None
self.draft_model: Transport | None = None

if getattr(args, 'model_draft', None):
model_draft = getattr(args, "model_draft", None)
if model_draft:
dm_args = copy.deepcopy(args)
dm_args.model_draft = None  # type: ignore
self.draft_model = TransportFactory(args.model_draft, dm_args, ignore_stderr=True).create()  # type: ignore
draft_model = TransportFactory(model_draft, dm_args, ignore_stderr=True).create()
if not isinstance(draft_model, Transport):
raise ValueError("Draft models must be local transports; hosted API transports are not supported.")
self.draft_model = draft_model

def detect_model_model_type(self) -> tuple[type[CLASS_MODEL_TYPES], Callable[[], CLASS_MODEL_TYPES]]:
for prefix in ["huggingface://", "hf://", "hf.co/"]:
if self.model.startswith(prefix):
match self.model:
case model if model.startswith(("huggingface://", "hf://", "hf.co/")):
return Huggingface, self.create_huggingface
for prefix in ["modelscope://", "ms://"]:
if self.model.startswith(prefix):
case model if model.startswith(("modelscope://", "ms://")):
return ModelScope, self.create_modelscope
for prefix in ["ollama://", "ollama.com/library/"]:
if self.model.startswith(prefix):
case model if model.startswith(("ollama://", "ollama.com/library/")):
return Ollama, self.create_ollama
for prefix in ["oci://", "docker://"]:
if self.model.startswith(prefix):
case model if model.startswith(("oci://", "docker://")):
return OCI, self.create_oci
if self.model.startswith("rlcr://"):
return RamalamaContainerRegistry, self.create_rlcr
for prefix in ["http://", "https://", "file:"]:
if self.model.startswith(prefix):
case model if model.startswith("rlcr://"):
return RamalamaContainerRegistry, self.create_rlcr
case model if model.startswith(("http://", "https://", "file:")):
return URL, self.create_url
if self.transport == "huggingface":
return Huggingface, self.create_huggingface
if self.transport == "modelscope":
return ModelScope, self.create_modelscope
if self.transport == "ollama":
return Ollama, self.create_ollama
if self.transport == "rlcr":
return RamalamaContainerRegistry, self.create_rlcr
if self.transport == "oci":
return OCI, self.create_oci
case model if model.startswith(("openai://")):
return APITransport, self.create_api_transport

match self.transport:
case "huggingface":
return Huggingface, self.create_huggingface
case "modelscope":
return ModelScope, self.create_modelscope
case "ollama":
return Ollama, self.create_ollama
case "rlcr":
return RamalamaContainerRegistry, self.create_rlcr
case "oci":
return OCI, self.create_oci

raise KeyError(f'transport "{self.transport}" not supported. Must be oci, huggingface, modelscope, or ollama.')

@@ -155,6 +162,10 @@ class TransportFactory:
model.draft_model = self.draft_model
return model

def create_api_transport(self) -> APITransport:
scheme = self.model.split("://", 1)[0]
return APITransport(self.pruned_model, provider=get_chat_provider(scheme))


def New(name, args, transport: str | None = None) -> CLASS_MODEL_TYPES:
if transport is None:

@@ -61,3 +61,4 @@
"smolvlm:500m" = "hf://ggml-org/SmolVLM-500M-Instruct-GGUF"
"tiny" = "hf://TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF"
"tinyllama" = "hf://TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF"
"gpt-5.1" = "openai://gpt-5.1-2025-11-13"
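The transport routing added above is purely prefix-driven; distilled into a standalone sketch (not the factory method itself, just its shape):

    def route(model: str) -> str:
        # Mirrors detect_model_model_type(): explicit prefixes win, otherwise the
        # configured transport (ollama by default) decides.
        match model:
            case m if m.startswith(("huggingface://", "hf://", "hf.co/")):
                return "huggingface"
            case m if m.startswith(("ollama://", "ollama.com/library/")):
                return "ollama"
            case m if m.startswith("openai://"):
                return "hosted api"
            case _:
                return "configured default"

    print(route("openai://gpt-5.1-2025-11-13"))  # hosted api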
1 test/unit/providers/__init__.py Normal file
@@ -0,0 +1 @@
# Package marker for provider-specific tests.
146 test/unit/providers/test_openai_provider.py Normal file
@@ -0,0 +1,146 @@
import json

import pytest

from ramalama.chat_providers.base import ChatRequestOptions
from ramalama.chat_providers.openai import OpenAICompletionsChatProvider, OpenAIResponsesChatProvider
from ramalama.chat_utils import AssistantMessage, ImageURLPart, ToolCall, ToolMessage, UserMessage


def build_payload(content):
    return {
        "choices": [
            {
                "delta": {
                    "content": content,
                }
            }
        ]
    }


def make_options(**overrides):
    data = {"model": "test-model", "stream": True}
    data.update(overrides)
    return ChatRequestOptions(**data)


class OpenAICompletionsProviderTests:
    def setup_method(self):
        self.provider = OpenAICompletionsChatProvider("http://example.com")

    def test_extracts_string_and_structured_deltas(self):
        assert self.provider._extract_delta(build_payload("hello")) == "hello"

        structured = build_payload(
            [
                {"type": "text", "text": "hello"},
                {"type": "output_text", "text": " world"},
            ]
        )
        assert self.provider._extract_delta(structured) == "hello world"

    def test_streaming_handles_structured_chunks(self):
        chunk = (
            b"data: "
            + json.dumps(
                build_payload(
                    [
                        {"type": "output_text", "text": "Hi"},
                        {"type": "output_text", "text": " there"},
                    ]
                )
            ).encode("utf-8")
            + b"\n\n"
        )

        events = list(self.provider.parse_stream_chunk(chunk))

        assert len(events) == 1
        assert events[0].text == "Hi there"

    def test_rejects_attachments(self):
        message = UserMessage(attachments=[ImageURLPart(url="http://img")])

        with pytest.raises(ValueError):
            self.provider.build_payload([message], make_options())

    def test_serializes_tool_calls_and_responses(self):
        tool_call = ToolCall(id="call-1", name="lookup", arguments={"query": "weather"})
        assistant = AssistantMessage(tool_calls=[tool_call])
        tool_reply = ToolMessage(text="72F and sunny", tool_call_id="call-1")

        payload = self.provider.build_payload([assistant, tool_reply], make_options())
        messages = payload["messages"]

        assert messages[0]["tool_calls"][0]["function"]["name"] == "lookup"
        assert messages[0]["tool_calls"][0]["function"]["arguments"] == '{"query": "weather"}'
        assert messages[1]["tool_call_id"] == "call-1"


class OpenAIResponsesProviderTests:
    def setup_method(self):
        self.provider = OpenAIResponsesChatProvider("http://example.com")

    def test_serializes_structured_content(self):
        message = UserMessage(
            text="hello",
            attachments=[ImageURLPart(url="http://img", detail="high")],
        )

        payload = self.provider.build_payload([message], make_options(max_tokens=128))
        serialized = payload["input"][0]["content"]

        assert serialized[0] == {"type": "input_text", "text": "hello"}
        assert serialized[1]["type"] == "image_url"
        assert serialized[1]["image_url"] == {"url": "http://img", "detail": "high"}
        assert payload["max_completion_tokens"] == 128
        assert "max_tokens" not in payload

    def test_streaming_emits_delta_and_completion_events(self):
        chunk = (
            b"event: response.output_text.delta\n"
            b'data: {"type":"response.output_text.delta","delta":{"text":"Hi"}}\n\n'
            b"event: response.completed\n"
            b'data: {"type":"response.completed"}\n\n'
        )

        events = list(self.provider.parse_stream_chunk(chunk))

        assert events[0].text == "Hi"
        assert events[1].done is True

    def test_serializes_tool_calls(self):
        tool_call = ToolCall(id="call-9", name="lookup", arguments={"city": "NYC"})
        assistant = AssistantMessage(tool_calls=[tool_call])
        tool_reply = ToolMessage(text="Clear skies", tool_call_id="call-9")

        payload = self.provider.build_payload([assistant, tool_reply], make_options())
        first_input = payload["input"][0]

        assert first_input["tool_calls"][0]["function"]["name"] == "lookup"
        assert first_input["tool_calls"][0]["function"]["arguments"] == '{"city": "NYC"}'
        assert payload["input"][1]["tool_call_id"] == "call-9"

    def test_streaming_emits_done_event_for_done_marker(self):
        events = list(self.provider.parse_stream_chunk(b"data: [DONE]\n\n"))

        assert len(events) == 1
        assert events[0].done is True

    def test_streaming_ignores_invalid_json_chunks(self):
        events = list(self.provider.parse_stream_chunk(b"data: {invalid-json\n\n"))

        assert events == []

    def test_streaming_extracts_text_from_done_events(self):
        chunk = (
            b"event: response.output_text.done\n"
            b'data: {"type":"response.output_text.done","output":[{"content":'
            b'[{"type":"output_text","text":"All done"}]}]}\n\n'
        )

        events = list(self.provider.parse_stream_chunk(chunk))

        assert len(events) == 1
        assert events[0].text == "All done"

17 test/unit/test_api_providers.py Normal file
@@ -0,0 +1,17 @@
import pytest

from ramalama.chat_providers.api_providers import get_chat_provider
from ramalama.chat_providers.openai import OpenAIResponsesChatProvider


def test_get_chat_provider_returns_openai_provider():
    provider = get_chat_provider("openai")

    assert isinstance(provider, OpenAIResponsesChatProvider)
    assert provider.base_url == "https://api.openai.com/v1"
    assert provider.provider == "openai"


def test_get_chat_provider_raises_for_unknown_scheme():
    with pytest.raises(ValueError):
        get_chat_provider("anthropic")

89 test/unit/test_api_transport.py Normal file
@@ -0,0 +1,89 @@
from types import SimpleNamespace

import pytest

from ramalama.chat_providers.openai import OpenAIResponsesChatProvider
from ramalama.config import CONFIG
from ramalama.transports import api as api_module
from ramalama.transports.api import APITransport


def make_provider(api_key: str = "provider-default") -> OpenAIResponsesChatProvider:
    return OpenAIResponsesChatProvider("https://api.openai.com/v1", api_key=api_key)


def test_api_transport_run(monkeypatch):
    provider = make_provider()
    transport = APITransport("gpt-4o-mini", provider)
    recorded: dict[str, object] = {}

    def fake_chat(args, operational_args=None, provider=None):
        recorded["args"] = args
        recorded["operational_args"] = operational_args
        recorded["provider"] = provider

    monkeypatch.setattr(api_module, "chat", fake_chat)

    args = SimpleNamespace(
        container=True, engine="podman", url="http://localhost", model=None, api="none", api_key=None
    )
    transport.run(args, [])

    assert args.container is False
    assert args.engine is None
    assert args.url == provider.base_url
    assert args.model == "gpt-4o-mini"
    assert recorded["args"] is args
    assert recorded["provider"] is provider
    assert provider.base_url == "http://localhost"
    assert provider.api_key == "provider-default"


def test_api_transport_ensure_exists_mutates_args(monkeypatch):
    provider = make_provider()
    transport = APITransport("gpt-4", provider)
    args = SimpleNamespace(container=True, engine="podman")
    monkeypatch.setattr(provider, "list_models", lambda: ["gpt-4", "other"])

    transport.ensure_model_exists(args)

    assert args.container is False
    assert args.engine is None


def test_api_transport_ensure_exists_requires_api_key(monkeypatch):
    monkeypatch.setattr(CONFIG, "api_key", None)
    provider = make_provider(api_key=None)
    transport = APITransport("gpt-4", provider)
    args = SimpleNamespace(container=True, engine="podman")

    with pytest.raises(ValueError, match="Missing API key"):
        transport.ensure_model_exists(args)


def test_api_transport_overrides_provider_api_key(monkeypatch):
    provider = make_provider()
    transport = APITransport("gpt-4o-mini", provider)

    recorded: dict[str, object] = {}

    def fake_chat(args, operational_args=None, provider=None):
        recorded["provider"] = provider

    monkeypatch.setattr(api_module, "chat", fake_chat)

    args = SimpleNamespace(container=True, engine="podman", url=None, model=None, api="none", api_key="cli-secret")
    transport.run(args, [])

    assert provider.api_key == "cli-secret"
    assert recorded["provider"] is provider


def test_api_transport_ensure_exists_raises_if_model_missing(monkeypatch):
    provider = make_provider()
    transport = APITransport("gpt-4", provider)
    monkeypatch.setattr(provider, "list_models", lambda: ["gpt-3.5"])
    args = SimpleNamespace(container=True, engine="podman")

    with pytest.raises(ValueError):
        transport.ensure_model_exists(args)

33 test/unit/test_chat_provider_base.py Normal file
@@ -0,0 +1,33 @@
import io
import urllib.error

import pytest

import ramalama.chat_providers.base as base_module
from ramalama.chat_providers.base import ChatProviderError
from ramalama.chat_providers.openai import OpenAIResponsesChatProvider


def test_list_models_reports_auth_error(monkeypatch):
    provider = OpenAIResponsesChatProvider("https://api.openai.com/v1", api_key="bad")
    error_body = b'{"error":{"message":"Invalid API key"}}'
    http_error = urllib.error.HTTPError(
        provider.build_url("/models"),
        401,
        "Unauthorized",
        {},
        io.BytesIO(error_body),
    )

    def fake_urlopen(request):
        raise http_error

    monkeypatch.setattr(base_module.urllib_request, "urlopen", fake_urlopen)

    with pytest.raises(ChatProviderError) as excinfo:
        provider.list_models()

    message = str(excinfo.value)
    assert "Could not authenticate with openai." in message
    assert "missing or invalid" in message
    assert "Invalid API key" in message
@@ -51,6 +51,7 @@ parser = get_parser()

special_cases = {
"api_key": "api-key",
"max_tokens": "max-tokens",
}


@@ -115,6 +116,10 @@ def test_default_endpoint(chatargs):
ChatSubArgs,
prefix=st.sampled_from(['> ', '🦙 > ', '🦭 > ', '🐋 > ']),
url=st.sampled_from(['https://test.com', 'test.com']),
temp=st.one_of(
st.none(),
st.floats(min_value=0, allow_nan=False, allow_infinity=False).map(lambda v: 0.0 if v == 0 else v),
),
)
)
def test_chat_endpoint(chatargs):
@@ -4,12 +4,21 @@ from unittest.mock import mock_open, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from ramalama.chat_utils import ImageURLPart
|
||||
from ramalama.file_loaders.file_manager import ImageFileManager, OpanAIChatAPIMessageBuilder, TextFileManager
|
||||
from ramalama.file_loaders.file_types.base import BaseFileLoader
|
||||
from ramalama.file_loaders.file_types.image import BasicImageFileLoader
|
||||
from ramalama.file_loaders.file_types.txt import TXTFileLoader
|
||||
|
||||
|
||||
def _text_content(message):
|
||||
return message.text or ""
|
||||
|
||||
|
||||
def _image_parts(message):
|
||||
return [attachment for attachment in message.attachments if isinstance(attachment, ImageURLPart)]
|
||||
|
||||
|
||||
class TestBaseFileLoader:
|
||||
"""Test the abstract base class for file upload handlers."""
|
||||
|
||||
@@ -269,9 +278,10 @@ class TestOpanAIChatAPIMessageBuilder:
|
||||
messages = builder.load(tmp_file.name)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0]["role"] == "system"
|
||||
assert "Test content" in messages[0]["content"]
|
||||
assert f"<!--start_document {tmp_file.name}-->" in messages[0]["content"]
|
||||
assert messages[0].role == "user"
|
||||
content = _text_content(messages[0])
|
||||
assert "Test content" in content
|
||||
assert f"<!--start_document {tmp_file.name}-->" in content
|
||||
|
||||
def test_builder_load_image_files_only(self):
|
||||
"""Test loading only image files."""
|
||||
@@ -283,12 +293,10 @@ class TestOpanAIChatAPIMessageBuilder:
|
||||
messages = builder.load(tmp_file.name)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0]["role"] == "system"
|
||||
assert isinstance(messages[0]["content"], list)
|
||||
assert len(messages[0]["content"]) == 1
|
||||
assert 'image_url' in messages[0]["content"][0]
|
||||
assert 'url' in messages[0]["content"][0]["image_url"]
|
||||
assert "data:image/" in messages[0]["content"][0]["image_url"]["url"]
|
||||
assert messages[0].role == "user"
|
||||
image_parts = _image_parts(messages[0])
|
||||
assert len(image_parts) == 1
|
||||
assert "data:image/" in image_parts[0].url
|
||||
|
||||
def test_builder_load_mixed_files(self):
|
||||
"""Test loading mixed text and image files."""
|
||||
@@ -306,12 +314,11 @@ class TestOpanAIChatAPIMessageBuilder:
|
||||
|
||||
assert len(messages) == 2
|
||||
# First message should be text
|
||||
assert messages[0]["role"] == "system"
|
||||
assert "Text content" in messages[0]["content"]
|
||||
assert messages[0].role == "user"
|
||||
assert "Text content" in _text_content(messages[0])
|
||||
# Second message should be image
|
||||
assert messages[1]["role"] == "system"
|
||||
assert isinstance(messages[1]["content"], list)
|
||||
assert len(messages[1]["content"]) == 1
|
||||
assert messages[1].role == "user"
|
||||
assert len(_image_parts(messages[1])) == 1
|
||||
|
||||
@pytest.mark.filterwarnings("ignore:.*Unsupported file types detected!.*")
|
||||
def test_builder_load_no_supported_files(self):
|
||||
@@ -405,7 +412,7 @@ class TestFileUploadIntegration:
|
||||
messages = builder.load(tmp_dir)
|
||||
|
||||
assert len(messages) == 1
|
||||
content = messages[0]["content"]
|
||||
content = _text_content(messages[0])
|
||||
for file_content in files_content.values():
|
||||
assert file_content in content
|
||||
|
||||
|
||||
@@ -5,6 +5,11 @@ from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
|
||||
from ramalama.chat import RamaLamaShell, chat
|
||||
from ramalama.chat_utils import ImageURLPart
|
||||
|
||||
|
||||
def _text_content(message):
|
||||
return message.text or ""
|
||||
|
||||
|
||||
class TestFileUploadChatIntegration:
|
||||
@@ -31,10 +36,11 @@ class TestFileUploadChatIntegration:
|
||||
|
||||
# Check that the system message was added to conversation history
|
||||
assert len(shell.conversation_history) == 1
|
||||
system_message = shell.conversation_history[0]
|
||||
assert system_message["role"] == "system"
|
||||
assert "This is test content for chat input" in system_message["content"]
|
||||
assert f"<!--start_document {tmp_file.name}-->" in system_message["content"]
|
||||
message = shell.conversation_history[0]
|
||||
assert message.role == "user"
|
||||
content = message.text or ""
|
||||
assert "This is test content for chat input" in content
|
||||
assert f"<!--start_document {tmp_file.name}-->" in content
|
||||
|
||||
@patch('urllib.request.urlopen')
|
||||
def test_chat_with_file_input_directory(self, mock_urlopen):
|
||||
@@ -62,13 +68,14 @@ class TestFileUploadChatIntegration:
|
||||
|
||||
# Check that the system message was added to conversation history
|
||||
assert len(shell.conversation_history) == 1
|
||||
system_message = shell.conversation_history[0]
|
||||
assert system_message["role"] == "system"
|
||||
assert "Text file content" in system_message["content"]
|
||||
assert "# Markdown Content" in system_message["content"]
|
||||
assert "test.txt" in system_message["content"]
|
||||
assert "readme.md" in system_message["content"]
|
||||
assert "<!--start_document" in system_message["content"]
|
||||
message = shell.conversation_history[0]
|
||||
assert message.role == "user"
|
||||
content = message.text or ""
|
||||
assert "Text file content" in content
|
||||
assert "# Markdown Content" in content
|
||||
assert "test.txt" in content
|
||||
assert "readme.md" in content
|
||||
assert "<!--start_document" in content
|
||||
|
||||
@pytest.mark.filterwarnings("ignore:.*Unsupported file types detected!.*")
|
||||
@patch('urllib.request.urlopen')
|
||||
@@ -125,11 +132,11 @@ class TestFileUploadChatIntegration:
|
||||
|
||||
# Check that the system message was added to conversation history
|
||||
assert len(shell.conversation_history) == 1
|
||||
system_message = shell.conversation_history[0]
|
||||
assert system_message["role"] == "system"
|
||||
assert f"<!--start_document {tmp_file.name}-->" in system_message["content"]
|
||||
# Empty file should still have the delimiter but no content
|
||||
assert system_message["content"].endswith(f"\n<!--start_document {tmp_file.name}-->\n")
|
||||
message = shell.conversation_history[0]
|
||||
assert message.role == "user"
|
||||
text = _text_content(message)
|
||||
assert f"<!--start_document {tmp_file.name}-->" in text
|
||||
assert text.endswith(f"\n<!--start_document {tmp_file.name}-->\n")
|
||||
|
||||
@patch('urllib.request.urlopen')
|
||||
def test_chat_with_file_input_unicode_content(self, mock_urlopen):
|
||||
@@ -153,10 +160,11 @@ class TestFileUploadChatIntegration:
|
||||
|
||||
# Check that the system message was added to conversation history
|
||||
assert len(shell.conversation_history) == 1
|
||||
system_message = shell.conversation_history[0]
|
||||
assert system_message["role"] == "system"
|
||||
assert unicode_content in system_message["content"]
|
||||
assert f"<!--start_document {tmp_file.name}-->" in system_message["content"]
|
||||
message = shell.conversation_history[0]
|
||||
assert message.role == "user"
|
||||
text = message.text or ""
|
||||
assert unicode_content in text
|
||||
assert f"<!--start_document {tmp_file.name}-->" in text
|
||||
|
||||
@patch('urllib.request.urlopen')
|
||||
def test_chat_with_file_input_mixed_content_types(self, mock_urlopen):
|
||||
@@ -188,15 +196,16 @@ class TestFileUploadChatIntegration:
|
||||
|
||||
# Check that the system message was added to conversation history
|
||||
assert len(shell.conversation_history) == 1
|
||||
system_message = shell.conversation_history[0]
|
||||
assert system_message["role"] == "system"
|
||||
assert "English content" in system_message["content"]
|
||||
assert '{"key": "value", "number": 42}' in system_message["content"]
|
||||
assert "setting: enabled" in system_message["content"]
|
||||
assert "values:" in system_message["content"]
|
||||
assert "english.txt" in system_message["content"]
|
||||
assert "data.json" in system_message["content"]
|
||||
assert "config.yaml" in system_message["content"]
|
||||
message = shell.conversation_history[0]
|
||||
assert message.role == "user"
|
||||
text = _text_content(message)
|
||||
assert "English content" in text
|
||||
assert '{"key": "value", "number": 42}' in text
|
||||
assert "setting: enabled" in text
|
||||
assert "values:" in text
|
||||
assert "english.txt" in text
|
||||
assert "data.json" in text
|
||||
assert "config.yaml" in text
|
||||
|
||||
@patch('urllib.request.urlopen')
|
||||
def test_chat_with_file_input_no_input_specified(self, mock_urlopen):
|
||||
@@ -237,10 +246,11 @@ class TestFileUploadChatIntegration:
|
||||
|
||||
# Check that the system message was added to conversation history
|
||||
assert len(shell.conversation_history) == 1
|
||||
system_message = shell.conversation_history[0]
|
||||
assert system_message["role"] == "system"
|
||||
assert "File content" in system_message["content"]
|
||||
assert f"<!--start_document {tmp_file.name}-->" in system_message["content"]
|
||||
message = shell.conversation_history[0]
|
||||
assert message.role == "user"
|
||||
text = _text_content(message)
|
||||
assert "File content" in text
|
||||
assert f"<!--start_document {tmp_file.name}-->" in text
|
||||
|
||||
def test_chat_function_with_rag_and_dryrun(self):
|
||||
"""Test that chat function works correctly with rag and dryrun."""
|
||||
@@ -290,14 +300,13 @@ class TestImageUploadChatIntegration:
|
||||
|
||||
# Check that the system message was added to conversation history
|
||||
assert len(shell.conversation_history) == 1
|
||||
system_message = shell.conversation_history[0]
|
||||
assert system_message["role"] == "system"
|
||||
assert isinstance(system_message["content"], list)
|
||||
assert len(system_message["content"]) == 1
|
||||
assert 'image_url' in system_message["content"][0]
|
||||
assert 'url' in system_message["content"][0]["image_url"]
|
||||
assert "data:image/" in system_message["content"][0]["image_url"]["url"]
|
||||
assert "base64," in system_message["content"][0]["image_url"]["url"]
|
||||
message = shell.conversation_history[0]
|
||||
assert message.role == "user"
|
||||
assert len(message.attachments) == 1
|
||||
part = message.attachments[0]
|
||||
assert isinstance(part, ImageURLPart)
|
||||
assert part.url.startswith("data:image/")
|
||||
assert "base64," in part.url
|
||||
|
||||
@patch('urllib.request.urlopen')
|
||||
def test_chat_with_image_input_directory(self, mock_urlopen):
|
||||
@@ -325,14 +334,13 @@ class TestImageUploadChatIntegration:
|
||||
|
||||
# Check that the system message was added to conversation history
|
||||
assert len(shell.conversation_history) == 1
|
||||
system_message = shell.conversation_history[0]
|
||||
assert system_message["role"] == "system"
|
||||
assert isinstance(system_message["content"], list)
|
||||
assert len(system_message["content"]) == 2
|
||||
assert all('image_url' in item for item in system_message["content"])
|
||||
assert all('url' in item["image_url"] for item in system_message["content"])
|
||||
assert all("data:image/" in item["image_url"]["url"] for item in system_message["content"])
|
||||
assert all("base64," in item["image_url"]["url"] for item in system_message["content"])
|
||||
message = shell.conversation_history[0]
|
||||
assert message.role == "user"
|
||||
assert len(message.attachments) == 2
|
||||
for part in message.attachments:
|
||||
assert isinstance(part, ImageURLPart)
|
||||
assert "data:image/" in part.url
|
||||
assert "base64," in part.url
|
||||
|
||||
@patch('urllib.request.urlopen')
|
||||
def test_chat_with_image_input_mixed_file_types(self, mock_urlopen):
|
||||
@@ -359,30 +367,22 @@ class TestImageUploadChatIntegration:
|
||||
shell = RamaLamaShell(mock_args)
|
||||
|
||||
# Check that two system messages were added to conversation history
|
||||
system_messages = [msg for msg in shell.conversation_history if msg["role"] == "system"]
|
||||
assert len(system_messages) == 2
|
||||
user_messages = [msg for msg in shell.conversation_history if msg.role == "user"]
|
||||
assert len(user_messages) == 2
|
||||
|
||||
# Determine which message is text and which is image
|
||||
if isinstance(system_messages[0]["content"], str):
|
||||
text_msg = system_messages[0]
|
||||
image_msg = system_messages[1]
|
||||
if user_messages[0].attachments:
|
||||
image_msg = user_messages[0]
|
||||
text_msg = user_messages[1]
|
||||
else:
|
||||
text_msg = system_messages[1]
|
||||
image_msg = system_messages[0]
|
||||
text_msg = user_messages[0]
|
||||
image_msg = user_messages[1]
|
||||
|
||||
# Assert text message content
|
||||
assert "Text content" in text_msg["content"]
|
||||
assert "readme.txt" in text_msg["content"]
|
||||
text = _text_content(text_msg)
|
||||
assert "Text content" in text
|
||||
assert "readme.txt" in text
|
||||
|
||||
# Assert image message content
|
||||
assert isinstance(image_msg["content"], list)
|
||||
assert any(
|
||||
isinstance(item, dict)
|
||||
and "image_url" in item
|
||||
and "url" in item["image_url"]
|
||||
and "data:image/" in item["image_url"]["url"]
|
||||
for item in image_msg["content"]
|
||||
)
|
||||
assert any(isinstance(part, ImageURLPart) for part in image_msg.attachments)
|
||||
|
||||
@pytest.mark.filterwarnings("ignore:.*Unsupported file types detected!.*")
|
||||
@patch('urllib.request.urlopen')
|
||||
@@ -434,11 +434,10 @@ class TestImageUploadChatIntegration:
|
||||
|
||||
# Check that the system message was added to conversation history
|
||||
assert len(shell.conversation_history) == 1
|
||||
system_message = shell.conversation_history[0]
|
||||
assert system_message["role"] == "system"
|
||||
assert isinstance(system_message["content"], list)
|
||||
assert len(system_message["content"]) == 2
|
||||
assert all('image_url' in item for item in system_message["content"])
|
||||
assert all('url' in item["image_url"] for item in system_message["content"])
|
||||
assert all("data:image/" in item["image_url"]["url"] for item in system_message["content"])
|
||||
assert all("base64," in item["image_url"]["url"] for item in system_message["content"])
|
||||
message = shell.conversation_history[0]
|
||||
assert message.role == "user"
|
||||
assert len(message.attachments) == 2
|
||||
for part in message.attachments:
|
||||
assert isinstance(part, ImageURLPart)
|
||||
assert "data:image/" in part.url
|
||||
assert "base64," in part.url
|
||||
|
||||
@@ -4,9 +4,18 @@ from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from ramalama.chat_utils import ImageURLPart
|
||||
from ramalama.file_loaders.file_manager import OpanAIChatAPIMessageBuilder
|
||||
|
||||
|
||||
def _text_content(message):
|
||||
return message.text or ""
|
||||
|
||||
|
||||
def _image_parts(message):
|
||||
return [attachment for attachment in message.attachments if isinstance(attachment, ImageURLPart)]
|
||||
|
||||
|
||||
class TestFileUploadWithDataFiles:
|
||||
"""Test file upload functionality using sample data files."""
|
||||
|
||||
@@ -24,10 +33,11 @@ class TestFileUploadWithDataFiles:
|
||||
messages = builder.load(str(txt_file))
|
||||
|
||||
assert len(messages) == 1
|
||||
assert "This is a sample text file" in messages[0]["content"]
|
||||
assert "TXTFileUpload class" in messages[0]["content"]
|
||||
assert "Special characters like: !@#$%^&*()" in messages[0]["content"]
|
||||
assert f"<!--start_document {txt_file}-->" in messages[0]["content"]
|
||||
content = _text_content(messages[0])
|
||||
assert "This is a sample text file" in content
|
||||
assert "TXTFileUpload class" in content
|
||||
assert "Special characters like: !@#$%^&*()" in content
|
||||
assert f"<!--start_document {txt_file}-->" in content
|
||||
|
||||
def test_load_single_markdown_file(self, data_dir):
|
||||
"""Test loading a single markdown file from the data directory."""
|
||||
@@ -37,11 +47,12 @@ class TestFileUploadWithDataFiles:
|
||||
messages = builder.load(str(md_file))
|
||||
|
||||
assert len(messages) == 1
|
||||
assert "# Sample Markdown File" in messages[0]["content"]
|
||||
assert "**Bold text** and *italic text*" in messages[0]["content"]
|
||||
assert "```python" in messages[0]["content"]
|
||||
assert "def hello_world():" in messages[0]["content"]
|
||||
assert f"<!--start_document {md_file}-->" in messages[0]["content"]
|
||||
content = _text_content(messages[0])
|
||||
assert "# Sample Markdown File" in content
|
||||
assert "**Bold text** and *italic text*" in content
|
||||
assert "```python" in content
|
||||
assert "def hello_world():" in content
|
||||
assert f"<!--start_document {md_file}-->" in content
|
||||
|
||||
def test_load_single_json_file(self, data_dir):
|
||||
"""Test loading a single JSON file from the data directory."""
|
||||
@@ -51,11 +62,12 @@ class TestFileUploadWithDataFiles:
|
||||
messages = builder.load(str(json_file))
|
||||
|
||||
assert len(messages) == 1
|
||||
assert '"name": "test_data"' in messages[0]["content"]
|
||||
assert '"version": "1.0.0"' in messages[0]["content"]
|
||||
assert '"text_processing"' in messages[0]["content"]
|
||||
assert '"supported_formats"' in messages[0]["content"]
|
||||
assert f"<!--start_document {json_file}-->" in messages[0]["content"]
|
||||
content = _text_content(messages[0])
|
||||
assert '"name": "test_data"' in content
|
||||
assert '"version": "1.0.0"' in content
|
||||
assert '"text_processing"' in content
|
||||
assert '"supported_formats"' in content
|
||||
assert f"<!--start_document {json_file}-->" in content
|
||||
|
||||
def test_load_single_yaml_file(self, data_dir):
|
||||
"""Test loading a single YAML file from the data directory."""
|
||||
@@ -65,12 +77,13 @@ class TestFileUploadWithDataFiles:
|
||||
messages = builder.load(str(yaml_file))
|
||||
|
||||
assert len(messages) == 1
|
||||
assert "name: test_config" in messages[0]["content"]
|
||||
assert "version: 1.0.0" in messages[0]["content"]
|
||||
assert "- text_processing" in messages[0]["content"]
|
||||
assert "- yaml_support" in messages[0]["content"]
|
||||
assert "deep:" in messages[0]["content"]
|
||||
assert f"<!--start_document {yaml_file}-->" in messages[0]["content"]
|
||||
content = _text_content(messages[0])
|
||||
assert "name: test_config" in content
|
||||
assert "version: 1.0.0" in content
|
||||
assert "- text_processing" in content
|
||||
assert "- yaml_support" in content
|
||||
assert "deep:" in content
|
||||
assert f"<!--start_document {yaml_file}-->" in content
|
||||
|
||||
def test_load_single_csv_file(self, data_dir):
|
||||
"""Test loading a single CSV file from the data directory."""
|
||||
@@ -80,11 +93,12 @@ class TestFileUploadWithDataFiles:
|
||||
messages = builder.load(str(csv_file))
|
||||
|
||||
assert len(messages) == 1
|
||||
assert "name,age,city,occupation" in messages[0]["content"]
|
||||
assert "John Doe,30,New York,Engineer" in messages[0]["content"]
|
||||
assert "Jane Smith,25,San Francisco,Designer" in messages[0]["content"]
|
||||
assert "Bob Johnson,35,Chicago,Manager" in messages[0]["content"]
|
||||
assert f"<!--start_document {csv_file}-->" in messages[0]["content"]
|
||||
content = _text_content(messages[0])
|
||||
assert "name,age,city,occupation" in content
|
||||
assert "John Doe,30,New York,Engineer" in content
|
||||
assert "Jane Smith,25,San Francisco,Designer" in content
|
||||
assert "Bob Johnson,35,Chicago,Manager" in content
|
||||
assert f"<!--start_document {csv_file}-->" in content
|
||||
|
||||
def test_load_single_toml_file(self, data_dir):
|
||||
"""Test loading a single TOML file from the data directory."""
|
||||
@@ -94,12 +108,13 @@ class TestFileUploadWithDataFiles:
|
||||
messages = builder.load(str(toml_file))
|
||||
|
||||
assert len(messages) == 1
|
||||
assert 'name = "test_config"' in messages[0]["content"]
|
||||
assert 'version = "1.0.0"' in messages[0]["content"]
|
||||
assert 'text_processing = true' in messages[0]["content"]
|
||||
assert 'toml_support = true' in messages[0]["content"]
|
||||
assert 'with_deep_nesting = true' in messages[0]["content"]
|
||||
assert f"<!--start_document {toml_file}-->" in messages[0]["content"]
|
||||
content = _text_content(messages[0])
|
||||
assert 'name = "test_config"' in content
|
||||
assert 'version = "1.0.0"' in content
|
||||
assert 'text_processing = true' in content
|
||||
assert 'toml_support = true' in content
|
||||
assert 'with_deep_nesting = true' in content
|
||||
assert f"<!--start_document {toml_file}-->" in content
|
||||
|
||||
def test_load_single_shell_script(self, data_dir):
|
||||
"""Test loading a single shell script from the data directory."""
|
||||
@@ -109,12 +124,13 @@ class TestFileUploadWithDataFiles:
|
||||
messages = builder.load(str(sh_file))
|
||||
|
||||
assert len(messages) == 1
|
||||
assert "#!/bin/bash" in messages[0]["content"]
|
||||
assert "Hello, World! This is a test script." in messages[0]["content"]
|
||||
assert "test_function()" in messages[0]["content"]
|
||||
assert "for i in {1..3}" in messages[0]["content"]
|
||||
assert "Script completed successfully!" in messages[0]["content"]
|
||||
assert f"<!--start_document {sh_file}-->" in messages[0]["content"]
|
||||
content = _text_content(messages[0])
|
||||
assert "#!/bin/bash" in content
|
||||
assert "Hello, World! This is a test script." in content
|
||||
assert "test_function()" in content
|
||||
assert "for i in {1..3}" in content
|
||||
assert "Script completed successfully!" in content
|
||||
assert f"<!--start_document {sh_file}-->" in content
|
||||
|
||||
def test_load_entire_data_directory(self, data_dir):
|
||||
"""Test loading all files from the data directory."""
|
||||
@@ -122,7 +138,7 @@ class TestFileUploadWithDataFiles:
|
||||
messages = builder.load(str(data_dir))
|
||||
|
||||
assert len(messages) == 1
|
||||
content = messages[0]["content"]
|
||||
content = _text_content(messages[0])
|
||||
assert "This is a sample text file" in content # sample.txt
|
||||
assert "# Sample Markdown File" in content # sample.md
|
||||
assert '"name": "test_data"' in content # sample.json
|
||||
@@ -151,7 +167,7 @@ class TestFileUploadWithDataFiles:
|
||||
messages = builder.load(str(txt_file))
|
||||
|
||||
assert len(messages) == 1
|
||||
content = messages[0]["content"]
|
||||
content = _text_content(messages[0])
|
||||
content_start = content.find('\n', content.find('<!--start_document')) + 1
|
||||
extracted_content = content[content_start:]
|
||||
|
||||
@@ -172,7 +188,7 @@ class TestFileUploadWithDataFiles:
|
||||
messages = builder.load(tmp_dir)
|
||||
|
||||
assert len(messages) == 1
|
||||
content = messages[0]["content"]
|
||||
content = _text_content(messages[0])
|
||||
assert "This is a sample text file" in content # sample.txt
|
||||
assert "# Sample Markdown File" in content # sample.md
|
||||
assert '"name": "test_data"' in content # sample.json
|
||||
@@ -200,7 +216,7 @@ class TestFileUploadWithDataFiles:
|
||||
messages = builder.load(tmp_dir)
|
||||
|
||||
assert len(messages) == 1
|
||||
content = messages[0]["content"]
|
||||
content = _text_content(messages[0])
|
||||
assert "This is a sample text file" in content
|
||||
assert "This is an unsupported file type" not in content
|
||||
assert "sample.txt" in content
|
||||
@@ -226,12 +242,10 @@ class TestImageUploadWithDataFiles:
messages = builder.load(tmp_file.name)

assert len(messages) == 1
assert isinstance(messages[0]["content"], list)
assert len(messages[0]["content"]) == 1
assert 'image_url' in messages[0]["content"][0]
assert 'url' in messages[0]["content"][0]["image_url"]
assert "data:image/" in messages[0]["content"][0]["image_url"]["url"]
assert "base64," in messages[0]["content"][0]["image_url"]["url"]
image_parts = _image_parts(messages[0])
assert len(image_parts) == 1
assert "data:image/" in image_parts[0].url
assert "base64," in image_parts[0].url

def test_load_multiple_image_files(self, data_dir):
"""Test loading multiple image files."""
@@ -252,12 +266,10 @@ class TestImageUploadWithDataFiles:
messages = builder.load(tmp_dir)

assert len(messages) == 1
assert isinstance(messages[0]["content"], list)
assert len(messages[0]["content"]) == 3
assert all('image_url' in item for item in messages[0]["content"])
assert all('url' in item["image_url"] for item in messages[0]["content"])
assert all("data:image/" in item["image_url"]["url"] for item in messages[0]["content"])
assert all("base64," in item["image_url"]["url"] for item in messages[0]["content"])
image_parts = _image_parts(messages[0])
assert len(image_parts) == 3
assert all("data:image/" in part.url for part in image_parts)
assert all("base64," in part.url for part in image_parts)

def test_image_file_content_integrity(self, data_dir):
"""Test that image file content is preserved exactly."""
@@ -270,11 +282,11 @@ class TestImageUploadWithDataFiles:
messages = builder.load(tmp_file.name)

assert len(messages) == 1
assert isinstance(messages[0]["content"], list)
assert len(messages[0]["content"]) == 1
image_parts = _image_parts(messages[0])
assert len(image_parts) == 1

# Extract base64 data from result
url = messages[0]["content"][0]["image_url"]["url"]
url = image_parts[0].url
base64_data = url.split("base64,")[1]
import base64

@@ -304,12 +316,10 @@ class TestImageUploadWithDataFiles:
messages = builder.load(tmp_dir)

assert len(messages) == 1
assert isinstance(messages[0]["content"], list)
assert len(messages[0]["content"]) == 8
assert all('image_url' in item for item in messages[0]["content"])
assert all('url' in item["image_url"] for item in messages[0]["content"])
assert all("data:image/" in item["image_url"]["url"] for item in messages[0]["content"])
assert all("base64," in item["image_url"]["url"] for item in messages[0]["content"])
image_parts = _image_parts(messages[0])
assert len(image_parts) == 8
assert all("data:image/" in part.url for part in image_parts)
assert all("base64," in part.url for part in image_parts)

@pytest.mark.filterwarnings("ignore:.*Unsupported file types detected!.*")
def test_image_unsupported_file_handling(self, data_dir):
@@ -327,12 +337,10 @@ class TestImageUploadWithDataFiles:
messages = builder.load(tmp_dir)

assert len(messages) == 1
assert isinstance(messages[0]["content"], list)
assert len(messages[0]["content"]) == 1
assert 'image_url' in messages[0]["content"][0]
assert 'url' in messages[0]["content"][0]["image_url"]
assert "data:image/" in messages[0]["content"][0]["image_url"]["url"]
assert "base64," in messages[0]["content"][0]["image_url"]["url"]
image_parts = _image_parts(messages[0])
assert len(image_parts) == 1
assert "data:image/" in image_parts[0].url
assert "base64," in image_parts[0].url

def test_image_case_insensitive_extensions(self, data_dir):
"""Test that image file extensions are handled case-insensitively."""
@@ -357,9 +365,7 @@ class TestImageUploadWithDataFiles:
messages = builder.load(tmp_dir)

assert len(messages) == 1
assert isinstance(messages[0]["content"], list)
assert len(messages[0]["content"]) == 8
assert all('image_url' in item for item in messages[0]["content"])
assert all('url' in item["image_url"] for item in messages[0]["content"])
assert all("data:image/" in item["image_url"]["url"] for item in messages[0]["content"])
assert all("base64," in item["image_url"]["url"] for item in messages[0]["content"])
image_parts = _image_parts(messages[0])
assert len(image_parts) == 8
assert all("data:image/" in part.url for part in image_parts)
assert all("base64," in part.url for part in image_parts)

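Along the same lines, the image tests now go through an `_image_parts` helper and assert on a `.url` attribute instead of digging into `item["image_url"]["url"]` dictionaries. That helper is also not shown in this excerpt; one plausible shape, assuming the content list carries OpenAI-style image parts and that the helper wraps each data URL in a small object, is:

```python
# Hypothetical sketch of a helper like _image_parts; the actual helper and the
# part objects it returns may differ. Assumes content parts look like
# {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}.
from dataclasses import dataclass


@dataclass
class _ImagePart:
    url: str


def _image_parts(message: dict) -> list[_ImagePart]:
    """Collect the image parts of a chat message as objects exposing a .url attribute."""
    parts: list[_ImagePart] = []
    content = message["content"]
    if not isinstance(content, list):
        return parts
    for item in content:
        if isinstance(item, dict) and "image_url" in item:
            parts.append(_ImagePart(url=item["image_url"]["url"]))
    return parts
```

Centralizing the lookup this way keeps the assertions focused on what matters (count, MIME prefix, base64 payload) rather than on the exact message dictionary layout.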
@@ -3,6 +3,9 @@ from typing import Union

import pytest

import ramalama.transports.transport_factory as transport_factory_module
from ramalama.chat_providers.openai import OpenAIResponsesChatProvider
from ramalama.transports.api import APITransport
from ramalama.transports.huggingface import Huggingface
from ramalama.transports.modelscope import ModelScope
from ramalama.transports.oci import OCI
@@ -35,6 +38,7 @@ hf_granite_blob = "https://huggingface.co/ibm-granite/granite-3b-code-base-2k-GG
"input,expected,error",
[
(Input("", "", ""), None, KeyError),
(Input("openai://gpt-4o-mini", "", ""), APITransport, None),
(Input("huggingface://granite-code", "", ""), Huggingface, None),
(Input("hf://granite-code", "", ""), Huggingface, None),
(Input("hf.co/granite-code", "", ""), Huggingface, None),
@@ -113,6 +117,7 @@ def test_validate_oci_model_input(input: Input, error):
@pytest.mark.parametrize(
"input,expected",
[
(Input("openai://gpt-4o-mini", "", ""), "gpt-4o-mini"),
(Input("huggingface://granite-code", "", ""), "granite-code"),
(
Input("huggingface://ibm-granite/granite-3b-code-base-2k-GGUF/granite-code", "", ""),
@@ -164,3 +169,21 @@ def test_prune_model_input(input: Input, expected: str):
args = ARGS(input.Engine)
pruned_model_input = TransportFactory(input.Model, args, input.Transport).prune_model_input()
assert pruned_model_input == expected


def test_transport_factory_passes_scheme_to_get_chat_provider(monkeypatch):
args = ARGS()
provider = OpenAIResponsesChatProvider("https://api.openai.com/v1")
captured: dict[str, str] = {}

def fake_get_chat_provider(scheme: str):
captured["scheme"] = scheme
return provider

monkeypatch.setattr(transport_factory_module, "get_chat_provider", fake_get_chat_provider)

transport = TransportFactory("openai://gpt-4o-mini", args).create()

assert captured["scheme"] == "openai"
assert isinstance(transport, APITransport)
assert transport.provider is provider
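The new test pins down one contract: for an `openai://` model reference, `TransportFactory.create()` extracts the URL scheme, hands it to `get_chat_provider`, and returns an `APITransport` bound to the provider that call returns. The sketch below is not the project's implementation, only a hypothetical dispatch (with a stand-in transport class and factory function) that would satisfy the behavior the test asserts:

```python
# Hypothetical sketch of the dispatch exercised by the test above; the real
# TransportFactory in ramalama.transports.transport_factory is more involved.
from urllib.parse import urlparse


class HostedAPITransport:
    """Stand-in for a transport backed by a hosted chat provider rather than a local runtime."""

    def __init__(self, model: str, provider):
        self.model = model
        self.provider = provider


def create_transport(model_input: str, get_chat_provider):
    parsed = urlparse(model_input)
    if parsed.scheme == "openai":  # hosted API scheme
        provider = get_chat_provider(parsed.scheme)
        # "openai://gpt-4o-mini" -> provider-side model id "gpt-4o-mini"
        return HostedAPITransport(parsed.netloc or parsed.path, provider)
    raise KeyError(f"no transport for {model_input!r}")
```

In this sketch the text after the scheme becomes the provider-side model identifier, which lines up with the `prune_model_input` expectation of `gpt-4o-mini` in the parametrized cases above.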