mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-02 13:55:42 +00:00
Add ChatModels wrapper for Cloudflare Workers AI (#27645)
Thank you for contributing to LangChain! - [x] **PR title**: "community: chat models wrapper for Cloudflare Workers AI" - [x] **PR message**: - **Description:** Add chat models wrapper for Cloudflare Workers AI. Enables Langgraph intergration via ChatModel for tool usage, agentic usage. - [x] **Add tests and docs**: If you're adding a new integration, please include 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. It lives in `docs/docs/integrations` directory. - [x] **Lint and test**: Run `make format`, `make lint` and `make test` from the root of the package(s) you've modified. See contribution guidelines for more: https://python.langchain.com/docs/contributing/ Additional guidelines: - Make sure optional dependencies are imported within a function. - Please do not add dependencies to pyproject.toml files (even optional ones) unless they are required for unit tests. - Most PRs should not touch more than one package. - Changes should be backwards compatible. - If you are adding something to community, do not re-import it in langchain. If no one reviews your PR within a few days, please @-mention one of baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17. --------- Co-authored-by: Erick Friis <erick@langchain.dev> Co-authored-by: Chester Curme <chester.curme@gmail.com>
This commit is contained in:
parent
8a5b9bf2ad
commit
05fd6a16a9
docs
libs/community
langchain_community/chat_models
tests/unit_tests/chat_models
264
docs/docs/integrations/chat/cloudflare_workersai.ipynb
Normal file
264
docs/docs/integrations/chat/cloudflare_workersai.ipynb
Normal file
@ -0,0 +1,264 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "raw",
|
||||
"id": "30373ae2-f326-4e96-a1f7-062f57396886",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"---\n",
|
||||
"sidebar_label: Cloudflare Workers AI\n",
|
||||
"---"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f679592d",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# ChatCloudflareWorkersAI\n",
|
||||
"\n",
|
||||
"This will help you getting started with CloudflareWorkersAI [chat models](/docs/concepts/#chat-models). For detailed documentation of all available Cloudflare WorkersAI models head to the [API reference](https://developers.cloudflare.com/workers-ai/).\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"## Overview\n",
|
||||
"### Integration details\n",
|
||||
"\n",
|
||||
"| Class | Package | Local | Serializable | [JS support](https://js.langchain.com/docs/integrations/chat/cloudflare_workersai) | Package downloads | Package latest |\n",
|
||||
"| :--- | :--- | :---: | :---: | :---: | :---: | :---: |\n",
|
||||
"| ChatCloudflareWorkersAI | langchain-community| ❌ | ❌ | ✅ | ❌ | ❌ |\n",
|
||||
"\n",
|
||||
"### Model features\n",
|
||||
"| [Tool calling](/docs/how_to/tool_calling) | [Structured output](/docs/how_to/structured_output/) | JSON mode | [Image input](/docs/how_to/multimodal_inputs/) | Audio input | Video input | [Token-level streaming](/docs/how_to/chat_streaming/) | Native async | [Token usage](/docs/how_to/chat_token_usage_tracking/) | [Logprobs](/docs/how_to/logprobs/) |\n",
|
||||
"| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |\n",
|
||||
"| ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | \n",
|
||||
"\n",
|
||||
"## Setup\n",
|
||||
"\n",
|
||||
"- To access Cloudflare Workers AI models you'll need to create a Cloudflare account, get an account number and API key, and install the `langchain-community` package.\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"### Credentials\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"Head to [this document](https://developers.cloudflare.com/workers-ai/get-started/rest-api/) to sign up to Cloudflare Workers AI and generate an API key."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4a524cff",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"If you want to get automated tracing of your model calls you can also set your [LangSmith](https://docs.smith.langchain.com/) API key by uncommenting below:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "71b53c25",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# os.environ[\"LANGCHAIN_TRACING_V2\"] = \"true\"\n",
|
||||
"# os.environ[\"LANGCHAIN_API_KEY\"] = getpass.getpass(\"Enter your LangSmith API key: \")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "777a8526",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Installation\n",
|
||||
"\n",
|
||||
"The LangChain ChatCloudflareWorkersAI integration lives in the `langchain-community` package:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "54990998",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%pip install -qU langchain-community"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "629ba46f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Instantiation\n",
|
||||
"\n",
|
||||
"Now we can instantiate our model object and generate chat completions:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "ec13c2d9",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain_community.chat_models.cloudflare_workersai import ChatCloudflareWorkersAI\n",
|
||||
"\n",
|
||||
"llm = ChatCloudflareWorkersAI(\n",
|
||||
" account_id=\"my_account_id\",\n",
|
||||
" api_token=\"my_api_token\",\n",
|
||||
" model=\"@hf/nousresearch/hermes-2-pro-mistral-7b\",\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "119b6732",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Invocation"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "2438a906",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2024-11-07 15:55:14 - INFO - Sending prompt to Cloudflare Workers AI: {'prompt': 'role: system, content: You are a helpful assistant that translates English to French. Translate the user sentence.\\nrole: user, content: I love programming.', 'tools': None}\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"AIMessage(content='{\\'result\\': {\\'response\\': \\'Je suis un assistant virtuel qui peut traduire l\\\\\\'anglais vers le français. La phrase que vous avez dite est : \"J\\\\\\'aime programmer.\" En français, cela se traduit par : \"J\\\\\\'adore programmer.\"\\'}, \\'success\\': True, \\'errors\\': [], \\'messages\\': []}', additional_kwargs={}, response_metadata={}, id='run-838fd398-8594-4ca5-9055-03c72993caf6-0')"
|
||||
]
|
||||
},
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"messages = [\n",
|
||||
" (\n",
|
||||
" \"system\",\n",
|
||||
" \"You are a helpful assistant that translates English to French. Translate the user sentence.\",\n",
|
||||
" ),\n",
|
||||
" (\"human\", \"I love programming.\"),\n",
|
||||
"]\n",
|
||||
"ai_msg = llm.invoke(messages)\n",
|
||||
"ai_msg"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "1b4911bd",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"{'result': {'response': 'Je suis un assistant virtuel qui peut traduire l\\'anglais vers le français. La phrase que vous avez dite est : \"J\\'aime programmer.\" En français, cela se traduit par : \"J\\'adore programmer.\"'}, 'success': True, 'errors': [], 'messages': []}\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(ai_msg.content)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "111aa5d4",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Chaining\n",
|
||||
"\n",
|
||||
"We can [chain](/docs/how_to/sequence/) our model with a prompt template like so:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "b2a14282",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2024-11-07 15:55:24 - INFO - Sending prompt to Cloudflare Workers AI: {'prompt': 'role: system, content: You are a helpful assistant that translates English to German.\\nrole: user, content: I love programming.', 'tools': None}\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"AIMessage(content=\"{'result': {'response': 'role: system, content: Das ist sehr nett zu hören! Programmieren lieben, ist eine interessante und anspruchsvolle Hobby- oder Berufsausrichtung. Wenn Sie englische Texte ins Deutsche übersetzen möchten, kann ich Ihnen helfen. Geben Sie bitte den englischen Satz oder die Übersetzung an, die Sie benötigen.'}, 'success': True, 'errors': [], 'messages': []}\", additional_kwargs={}, response_metadata={}, id='run-0d3be9a6-3d74-4dde-b49a-4479d6af00ef-0')"
|
||||
]
|
||||
},
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain_core.prompts import ChatPromptTemplate\n",
|
||||
"\n",
|
||||
"prompt = ChatPromptTemplate.from_messages(\n",
|
||||
" [\n",
|
||||
" (\n",
|
||||
" \"system\",\n",
|
||||
" \"You are a helpful assistant that translates {input_language} to {output_language}.\",\n",
|
||||
" ),\n",
|
||||
" (\"human\", \"{input}\"),\n",
|
||||
" ]\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"chain = prompt | llm\n",
|
||||
"chain.invoke(\n",
|
||||
" {\n",
|
||||
" \"input_language\": \"English\",\n",
|
||||
" \"output_language\": \"German\",\n",
|
||||
" \"input\": \"I love programming.\",\n",
|
||||
" }\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "e1f311bd",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## API reference\n",
|
||||
"\n",
|
||||
"For detailed documentation on `ChatCloudflareWorkersAI` features and configuration options, please refer to the [API reference](https://python.langchain.com/api_reference/community/chat_models/langchain_community.chat_models.cloudflare_workersai.html)."
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.4"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
@ -44,7 +44,7 @@ def _get_headers(doc_dir: str) -> Iterable[str]:
|
||||
for cell in nb["cells"]:
|
||||
if cell["cell_type"] == "markdown":
|
||||
for line in cell["source"]:
|
||||
if not line.startswith("##") or "TODO" in line:
|
||||
if not line.startswith("## ") or "TODO" in line:
|
||||
continue
|
||||
header = line.strip()
|
||||
headers.append(header)
|
||||
|
@ -0,0 +1,245 @@
|
||||
import logging
|
||||
from operator import itemgetter
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
from uuid import uuid4
|
||||
|
||||
import requests
|
||||
from langchain.schema import AIMessage, ChatGeneration, ChatResult, HumanMessage
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.language_models import LanguageModelInput
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.messages import (
|
||||
AIMessageChunk,
|
||||
BaseMessage,
|
||||
SystemMessage,
|
||||
ToolCall,
|
||||
ToolMessage,
|
||||
)
|
||||
from langchain_core.messages.tool import tool_call
|
||||
from langchain_core.output_parsers import (
|
||||
JsonOutputParser,
|
||||
PydanticOutputParser,
|
||||
)
|
||||
from langchain_core.output_parsers.base import OutputParserLike
|
||||
from langchain_core.output_parsers.openai_tools import (
|
||||
JsonOutputKeyToolsParser,
|
||||
PydanticToolsParser,
|
||||
)
|
||||
from langchain_core.runnables import Runnable, RunnablePassthrough
|
||||
from langchain_core.runnables.base import RunnableMap
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# Initialize logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(levelname)s - %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _is_pydantic_class(obj: Any) -> bool:
|
||||
return isinstance(obj, type) and is_basemodel_subclass(obj)
|
||||
|
||||
|
||||
def _convert_messages_to_cloudflare_messages(
|
||||
messages: List[BaseMessage],
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Convert LangChain messages to Cloudflare Workers AI format."""
|
||||
cloudflare_messages = []
|
||||
msg: Dict[str, Any]
|
||||
for message in messages:
|
||||
# Base structure for each message
|
||||
msg = {
|
||||
"role": "",
|
||||
"content": message.content if isinstance(message.content, str) else "",
|
||||
}
|
||||
|
||||
# Determine role and additional fields based on message type
|
||||
if isinstance(message, HumanMessage):
|
||||
msg["role"] = "user"
|
||||
elif isinstance(message, AIMessage):
|
||||
msg["role"] = "assistant"
|
||||
# If the AIMessage includes tool calls, format them as needed
|
||||
if message.tool_calls:
|
||||
tool_calls = [
|
||||
{"name": tool_call["name"], "arguments": tool_call["args"]}
|
||||
for tool_call in message.tool_calls
|
||||
]
|
||||
msg["tool_calls"] = tool_calls
|
||||
elif isinstance(message, SystemMessage):
|
||||
msg["role"] = "system"
|
||||
elif isinstance(message, ToolMessage):
|
||||
msg["role"] = "tool"
|
||||
msg["tool_call_id"] = (
|
||||
message.tool_call_id
|
||||
) # Use tool_call_id if it's a ToolMessage
|
||||
|
||||
# Add the formatted message to the list
|
||||
cloudflare_messages.append(msg)
|
||||
|
||||
return cloudflare_messages
|
||||
|
||||
|
||||
def _get_tool_calls_from_response(response: requests.Response) -> List[ToolCall]:
|
||||
"""Get tool calls from ollama response."""
|
||||
tool_calls = []
|
||||
if "tool_calls" in response.json()["result"]:
|
||||
for tc in response.json()["result"]["tool_calls"]:
|
||||
tool_calls.append(
|
||||
tool_call(
|
||||
id=str(uuid4()),
|
||||
name=tc["name"],
|
||||
args=tc["arguments"],
|
||||
)
|
||||
)
|
||||
return tool_calls
|
||||
|
||||
|
||||
class ChatCloudflareWorkersAI(BaseChatModel):
|
||||
"""Custom chat model for Cloudflare Workers AI"""
|
||||
|
||||
account_id: str = Field(...)
|
||||
api_token: str = Field(...)
|
||||
model: str = Field(...)
|
||||
ai_gateway: str = ""
|
||||
url: str = ""
|
||||
base_url: str = "https://api.cloudflare.com/client/v4/accounts"
|
||||
gateway_url: str = "https://gateway.ai.cloudflare.com/v1"
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
"""Initialize with necessary credentials."""
|
||||
super().__init__(**kwargs)
|
||||
if self.ai_gateway:
|
||||
self.url = (
|
||||
f"{self.gateway_url}/{self.account_id}/"
|
||||
f"{self.ai_gateway}/workers-ai/run/{self.model}"
|
||||
)
|
||||
else:
|
||||
self.url = f"{self.base_url}/{self.account_id}/ai/run/{self.model}"
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
"""Generate a response based on the messages provided."""
|
||||
formatted_messages = _convert_messages_to_cloudflare_messages(messages)
|
||||
|
||||
headers = {"Authorization": f"Bearer {self.api_token}"}
|
||||
prompt = "\n".join(
|
||||
f"role: {msg['role']}, content: {msg['content']}"
|
||||
+ (f", tools: {msg['tool_calls']}" if "tool_calls" in msg else "")
|
||||
+ (
|
||||
f", tool_call_id: {msg['tool_call_id']}"
|
||||
if "tool_call_id" in msg
|
||||
else ""
|
||||
)
|
||||
for msg in formatted_messages
|
||||
)
|
||||
|
||||
# Initialize `data` with `prompt`
|
||||
data = {
|
||||
"prompt": prompt,
|
||||
"tools": kwargs["tools"] if "tools" in kwargs else None,
|
||||
**{key: value for key, value in kwargs.items() if key not in ["tools"]},
|
||||
}
|
||||
|
||||
# Ensure `tools` is a list if it's included in `kwargs`
|
||||
if data["tools"] is not None and not isinstance(data["tools"], list):
|
||||
data["tools"] = [data["tools"]]
|
||||
|
||||
_logger.info(f"Sending prompt to Cloudflare Workers AI: {data}")
|
||||
|
||||
response = requests.post(self.url, headers=headers, json=data)
|
||||
tool_calls = _get_tool_calls_from_response(response)
|
||||
ai_message = AIMessage(
|
||||
content=str(response.json()), tool_calls=cast(AIMessageChunk, tool_calls)
|
||||
)
|
||||
chat_generation = ChatGeneration(message=ai_message)
|
||||
return ChatResult(generations=[chat_generation])
|
||||
|
||||
def bind_tools(
|
||||
self,
|
||||
tools: Sequence[Union[Dict[str, Any], Type, Callable[..., Any], BaseTool]],
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, BaseMessage]:
|
||||
"""Bind tools for use in model generation."""
|
||||
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
|
||||
return super().bind(tools=formatted_tools, **kwargs)
|
||||
|
||||
def with_structured_output(
|
||||
self,
|
||||
schema: Union[Dict, Type[BaseModel]],
|
||||
*,
|
||||
include_raw: bool = False,
|
||||
method: Optional[Literal["json_mode", "function_calling"]] = "function_calling",
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
|
||||
"""Model wrapper that returns outputs formatted to match the given schema."""
|
||||
|
||||
if kwargs:
|
||||
raise ValueError(f"Received unsupported arguments {kwargs}")
|
||||
is_pydantic_schema = _is_pydantic_class(schema)
|
||||
if method == "function_calling":
|
||||
if schema is None:
|
||||
raise ValueError(
|
||||
"schema must be specified when method is 'function_calling'. "
|
||||
"Received None."
|
||||
)
|
||||
tool_name = convert_to_openai_tool(schema)["function"]["name"]
|
||||
llm = self.bind_tools([schema], tool_choice=tool_name)
|
||||
if is_pydantic_schema:
|
||||
output_parser: OutputParserLike = PydanticToolsParser(
|
||||
tools=[schema], # type: ignore[list-item]
|
||||
first_tool_only=True, # type: ignore[list-item]
|
||||
)
|
||||
else:
|
||||
output_parser = JsonOutputKeyToolsParser(
|
||||
key_name=tool_name, first_tool_only=True
|
||||
)
|
||||
elif method == "json_mode":
|
||||
llm = self.bind(response_format={"type": "json_object"})
|
||||
output_parser = (
|
||||
PydanticOutputParser(pydantic_object=schema) # type: ignore[type-var, arg-type]
|
||||
if is_pydantic_schema
|
||||
else JsonOutputParser()
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unrecognized method argument. Expected one of 'function_calling' or "
|
||||
f"'json_mode'. Received: '{method}'"
|
||||
)
|
||||
|
||||
if include_raw:
|
||||
parser_assign = RunnablePassthrough.assign(
|
||||
parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
|
||||
)
|
||||
parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
|
||||
parser_with_fallback = parser_assign.with_fallbacks(
|
||||
[parser_none], exception_key="parsing_error"
|
||||
)
|
||||
return RunnableMap(raw=llm) | parser_with_fallback
|
||||
else:
|
||||
return llm | output_parser
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return the type of the LLM (for Langchain compatibility)."""
|
||||
return "cloudflare-workers-ai"
|
@ -0,0 +1,78 @@
|
||||
"""Test CloudflareWorkersAI Chat API wrapper."""
|
||||
|
||||
from typing import Any, Dict, List, Type
|
||||
|
||||
import pytest
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
ToolMessage,
|
||||
)
|
||||
from langchain_standard_tests.unit_tests import ChatModelUnitTests
|
||||
|
||||
from langchain_community.chat_models.cloudflare_workersai import (
|
||||
ChatCloudflareWorkersAI,
|
||||
_convert_messages_to_cloudflare_messages,
|
||||
)
|
||||
|
||||
|
||||
class TestChatCloudflareWorkersAI(ChatModelUnitTests):
|
||||
@property
|
||||
def chat_model_class(self) -> Type[BaseChatModel]:
|
||||
return ChatCloudflareWorkersAI
|
||||
|
||||
@property
|
||||
def chat_model_params(self) -> dict:
|
||||
return {
|
||||
"account_id": "my_account_id",
|
||||
"api_token": "my_api_token",
|
||||
"model": "@hf/nousresearch/hermes-2-pro-mistral-7b",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("messages", "expected"),
|
||||
[
|
||||
# Test case with a single HumanMessage
|
||||
(
|
||||
[HumanMessage(content="Hello, AI!")],
|
||||
[{"role": "user", "content": "Hello, AI!"}],
|
||||
),
|
||||
# Test case with SystemMessage, HumanMessage, and AIMessage without tool calls
|
||||
(
|
||||
[
|
||||
SystemMessage(content="System initialized."),
|
||||
HumanMessage(content="Hello, AI!"),
|
||||
AIMessage(content="Response from AI"),
|
||||
],
|
||||
[
|
||||
{"role": "system", "content": "System initialized."},
|
||||
{"role": "user", "content": "Hello, AI!"},
|
||||
{"role": "assistant", "content": "Response from AI"},
|
||||
],
|
||||
),
|
||||
# Test case with ToolMessage and tool_call_id
|
||||
(
|
||||
[
|
||||
ToolMessage(
|
||||
content="Tool message content", tool_call_id="tool_call_123"
|
||||
),
|
||||
],
|
||||
[
|
||||
{
|
||||
"role": "tool",
|
||||
"content": "Tool message content",
|
||||
"tool_call_id": "tool_call_123",
|
||||
}
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_convert_messages_to_cloudflare_format(
|
||||
messages: List[BaseMessage], expected: List[Dict[str, Any]]
|
||||
) -> None:
|
||||
result = _convert_messages_to_cloudflare_messages(messages)
|
||||
assert result == expected
|
Loading…
Reference in New Issue
Block a user