1
0
mirror of https://github.com/hwchase17/langchain.git synced 2025-05-02 13:55:42 +00:00

Add ChatModels wrapper for Cloudflare Workers AI ()

Thank you for contributing to LangChain!

- [x] **PR title**: "community: chat models wrapper for Cloudflare
Workers AI"


- [x] **PR message**:
- **Description:** Add chat models wrapper for Cloudflare Workers AI.
Enables Langgraph intergration via ChatModel for tool usage, agentic
usage.


- [x] **Add tests and docs**: If you're adding a new integration, please
include
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in
`docs/docs/integrations` directory.


- [x] **Lint and test**: Run `make format`, `make lint` and `make test`
from the root of the package(s) you've modified. See contribution
guidelines for more: https://python.langchain.com/docs/contributing/

Additional guidelines:
- Make sure optional dependencies are imported within a function.
- Please do not add dependencies to pyproject.toml files (even optional
ones) unless they are required for unit tests.
- Most PRs should not touch more than one package.
- Changes should be backwards compatible.
- If you are adding something to community, do not re-import it in
langchain.

If no one reviews your PR within a few days, please @-mention one of
baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17.

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
Co-authored-by: Chester Curme <chester.curme@gmail.com>
This commit is contained in:
Akshata 2024-11-07 14:34:24 -06:00 committed by GitHub
parent 8a5b9bf2ad
commit 05fd6a16a9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 588 additions and 1 deletions
docs
libs/community
langchain_community/chat_models
tests/unit_tests/chat_models

View File

@ -0,0 +1,264 @@
{
"cells": [
{
"cell_type": "raw",
"id": "30373ae2-f326-4e96-a1f7-062f57396886",
"metadata": {},
"source": [
"---\n",
"sidebar_label: Cloudflare Workers AI\n",
"---"
]
},
{
"cell_type": "markdown",
"id": "f679592d",
"metadata": {},
"source": [
"# ChatCloudflareWorkersAI\n",
"\n",
"This will help you getting started with CloudflareWorkersAI [chat models](/docs/concepts/#chat-models). For detailed documentation of all available Cloudflare WorkersAI models head to the [API reference](https://developers.cloudflare.com/workers-ai/).\n",
"\n",
"\n",
"## Overview\n",
"### Integration details\n",
"\n",
"| Class | Package | Local | Serializable | [JS support](https://js.langchain.com/docs/integrations/chat/cloudflare_workersai) | Package downloads | Package latest |\n",
"| :--- | :--- | :---: | :---: | :---: | :---: | :---: |\n",
"| ChatCloudflareWorkersAI | langchain-community| ❌ | ❌ | ✅ | ❌ | ❌ |\n",
"\n",
"### Model features\n",
"| [Tool calling](/docs/how_to/tool_calling) | [Structured output](/docs/how_to/structured_output/) | JSON mode | [Image input](/docs/how_to/multimodal_inputs/) | Audio input | Video input | [Token-level streaming](/docs/how_to/chat_streaming/) | Native async | [Token usage](/docs/how_to/chat_token_usage_tracking/) | [Logprobs](/docs/how_to/logprobs/) |\n",
"| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |\n",
"| ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | \n",
"\n",
"## Setup\n",
"\n",
"- To access Cloudflare Workers AI models you'll need to create a Cloudflare account, get an account number and API key, and install the `langchain-community` package.\n",
"\n",
"\n",
"### Credentials\n",
"\n",
"\n",
"Head to [this document](https://developers.cloudflare.com/workers-ai/get-started/rest-api/) to sign up to Cloudflare Workers AI and generate an API key."
]
},
{
"cell_type": "markdown",
"id": "4a524cff",
"metadata": {},
"source": [
"If you want to get automated tracing of your model calls you can also set your [LangSmith](https://docs.smith.langchain.com/) API key by uncommenting below:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "71b53c25",
"metadata": {},
"outputs": [],
"source": [
"# os.environ[\"LANGCHAIN_TRACING_V2\"] = \"true\"\n",
"# os.environ[\"LANGCHAIN_API_KEY\"] = getpass.getpass(\"Enter your LangSmith API key: \")"
]
},
{
"cell_type": "markdown",
"id": "777a8526",
"metadata": {},
"source": [
"### Installation\n",
"\n",
"The LangChain ChatCloudflareWorkersAI integration lives in the `langchain-community` package:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "54990998",
"metadata": {},
"outputs": [],
"source": [
"%pip install -qU langchain-community"
]
},
{
"cell_type": "markdown",
"id": "629ba46f",
"metadata": {},
"source": [
"## Instantiation\n",
"\n",
"Now we can instantiate our model object and generate chat completions:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ec13c2d9",
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.chat_models.cloudflare_workersai import ChatCloudflareWorkersAI\n",
"\n",
"llm = ChatCloudflareWorkersAI(\n",
" account_id=\"my_account_id\",\n",
" api_token=\"my_api_token\",\n",
" model=\"@hf/nousresearch/hermes-2-pro-mistral-7b\",\n",
")"
]
},
{
"cell_type": "markdown",
"id": "119b6732",
"metadata": {},
"source": [
"## Invocation"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "2438a906",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-11-07 15:55:14 - INFO - Sending prompt to Cloudflare Workers AI: {'prompt': 'role: system, content: You are a helpful assistant that translates English to French. Translate the user sentence.\\nrole: user, content: I love programming.', 'tools': None}\n"
]
},
{
"data": {
"text/plain": [
"AIMessage(content='{\\'result\\': {\\'response\\': \\'Je suis un assistant virtuel qui peut traduire l\\\\\\'anglais vers le français. La phrase que vous avez dite est : \"J\\\\\\'aime programmer.\" En français, cela se traduit par : \"J\\\\\\'adore programmer.\"\\'}, \\'success\\': True, \\'errors\\': [], \\'messages\\': []}', additional_kwargs={}, response_metadata={}, id='run-838fd398-8594-4ca5-9055-03c72993caf6-0')"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"messages = [\n",
" (\n",
" \"system\",\n",
" \"You are a helpful assistant that translates English to French. Translate the user sentence.\",\n",
" ),\n",
" (\"human\", \"I love programming.\"),\n",
"]\n",
"ai_msg = llm.invoke(messages)\n",
"ai_msg"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "1b4911bd",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'result': {'response': 'Je suis un assistant virtuel qui peut traduire l\\'anglais vers le français. La phrase que vous avez dite est : \"J\\'aime programmer.\" En français, cela se traduit par : \"J\\'adore programmer.\"'}, 'success': True, 'errors': [], 'messages': []}\n"
]
}
],
"source": [
"print(ai_msg.content)"
]
},
{
"cell_type": "markdown",
"id": "111aa5d4",
"metadata": {},
"source": [
"## Chaining\n",
"\n",
"We can [chain](/docs/how_to/sequence/) our model with a prompt template like so:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "b2a14282",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-11-07 15:55:24 - INFO - Sending prompt to Cloudflare Workers AI: {'prompt': 'role: system, content: You are a helpful assistant that translates English to German.\\nrole: user, content: I love programming.', 'tools': None}\n"
]
},
{
"data": {
"text/plain": [
"AIMessage(content=\"{'result': {'response': 'role: system, content: Das ist sehr nett zu hören! Programmieren lieben, ist eine interessante und anspruchsvolle Hobby- oder Berufsausrichtung. Wenn Sie englische Texte ins Deutsche übersetzen möchten, kann ich Ihnen helfen. Geben Sie bitte den englischen Satz oder die Übersetzung an, die Sie benötigen.'}, 'success': True, 'errors': [], 'messages': []}\", additional_kwargs={}, response_metadata={}, id='run-0d3be9a6-3d74-4dde-b49a-4479d6af00ef-0')"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain_core.prompts import ChatPromptTemplate\n",
"\n",
"prompt = ChatPromptTemplate.from_messages(\n",
" [\n",
" (\n",
" \"system\",\n",
" \"You are a helpful assistant that translates {input_language} to {output_language}.\",\n",
" ),\n",
" (\"human\", \"{input}\"),\n",
" ]\n",
")\n",
"\n",
"chain = prompt | llm\n",
"chain.invoke(\n",
" {\n",
" \"input_language\": \"English\",\n",
" \"output_language\": \"German\",\n",
" \"input\": \"I love programming.\",\n",
" }\n",
")"
]
},
{
"cell_type": "markdown",
"id": "e1f311bd",
"metadata": {},
"source": [
"## API reference\n",
"\n",
"For detailed documentation on `ChatCloudflareWorkersAI` features and configuration options, please refer to the [API reference](https://python.langchain.com/api_reference/community/chat_models/langchain_community.chat_models.cloudflare_workersai.html)."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -44,7 +44,7 @@ def _get_headers(doc_dir: str) -> Iterable[str]:
for cell in nb["cells"]:
if cell["cell_type"] == "markdown":
for line in cell["source"]:
if not line.startswith("##") or "TODO" in line:
if not line.startswith("## ") or "TODO" in line:
continue
header = line.strip()
headers.append(header)

View File

@ -0,0 +1,245 @@
import logging
from operator import itemgetter
from typing import (
Any,
Callable,
Dict,
List,
Literal,
Optional,
Sequence,
Type,
Union,
cast,
)
from uuid import uuid4
import requests
from langchain.schema import AIMessage, ChatGeneration, ChatResult, HumanMessage
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
AIMessageChunk,
BaseMessage,
SystemMessage,
ToolCall,
ToolMessage,
)
from langchain_core.messages.tool import tool_call
from langchain_core.output_parsers import (
JsonOutputParser,
PydanticOutputParser,
)
from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.output_parsers.openai_tools import (
JsonOutputKeyToolsParser,
PydanticToolsParser,
)
from langchain_core.runnables import Runnable, RunnablePassthrough
from langchain_core.runnables.base import RunnableMap
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_core.utils.pydantic import is_basemodel_subclass
from pydantic import BaseModel, Field
# Initialize logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
_logger = logging.getLogger(__name__)
def _is_pydantic_class(obj: Any) -> bool:
return isinstance(obj, type) and is_basemodel_subclass(obj)
def _convert_messages_to_cloudflare_messages(
messages: List[BaseMessage],
) -> List[Dict[str, Any]]:
"""Convert LangChain messages to Cloudflare Workers AI format."""
cloudflare_messages = []
msg: Dict[str, Any]
for message in messages:
# Base structure for each message
msg = {
"role": "",
"content": message.content if isinstance(message.content, str) else "",
}
# Determine role and additional fields based on message type
if isinstance(message, HumanMessage):
msg["role"] = "user"
elif isinstance(message, AIMessage):
msg["role"] = "assistant"
# If the AIMessage includes tool calls, format them as needed
if message.tool_calls:
tool_calls = [
{"name": tool_call["name"], "arguments": tool_call["args"]}
for tool_call in message.tool_calls
]
msg["tool_calls"] = tool_calls
elif isinstance(message, SystemMessage):
msg["role"] = "system"
elif isinstance(message, ToolMessage):
msg["role"] = "tool"
msg["tool_call_id"] = (
message.tool_call_id
) # Use tool_call_id if it's a ToolMessage
# Add the formatted message to the list
cloudflare_messages.append(msg)
return cloudflare_messages
def _get_tool_calls_from_response(response: requests.Response) -> List[ToolCall]:
"""Get tool calls from ollama response."""
tool_calls = []
if "tool_calls" in response.json()["result"]:
for tc in response.json()["result"]["tool_calls"]:
tool_calls.append(
tool_call(
id=str(uuid4()),
name=tc["name"],
args=tc["arguments"],
)
)
return tool_calls
class ChatCloudflareWorkersAI(BaseChatModel):
"""Custom chat model for Cloudflare Workers AI"""
account_id: str = Field(...)
api_token: str = Field(...)
model: str = Field(...)
ai_gateway: str = ""
url: str = ""
base_url: str = "https://api.cloudflare.com/client/v4/accounts"
gateway_url: str = "https://gateway.ai.cloudflare.com/v1"
def __init__(self, **kwargs: Any) -> None:
"""Initialize with necessary credentials."""
super().__init__(**kwargs)
if self.ai_gateway:
self.url = (
f"{self.gateway_url}/{self.account_id}/"
f"{self.ai_gateway}/workers-ai/run/{self.model}"
)
else:
self.url = f"{self.base_url}/{self.account_id}/ai/run/{self.model}"
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
"""Generate a response based on the messages provided."""
formatted_messages = _convert_messages_to_cloudflare_messages(messages)
headers = {"Authorization": f"Bearer {self.api_token}"}
prompt = "\n".join(
f"role: {msg['role']}, content: {msg['content']}"
+ (f", tools: {msg['tool_calls']}" if "tool_calls" in msg else "")
+ (
f", tool_call_id: {msg['tool_call_id']}"
if "tool_call_id" in msg
else ""
)
for msg in formatted_messages
)
# Initialize `data` with `prompt`
data = {
"prompt": prompt,
"tools": kwargs["tools"] if "tools" in kwargs else None,
**{key: value for key, value in kwargs.items() if key not in ["tools"]},
}
# Ensure `tools` is a list if it's included in `kwargs`
if data["tools"] is not None and not isinstance(data["tools"], list):
data["tools"] = [data["tools"]]
_logger.info(f"Sending prompt to Cloudflare Workers AI: {data}")
response = requests.post(self.url, headers=headers, json=data)
tool_calls = _get_tool_calls_from_response(response)
ai_message = AIMessage(
content=str(response.json()), tool_calls=cast(AIMessageChunk, tool_calls)
)
chat_generation = ChatGeneration(message=ai_message)
return ChatResult(generations=[chat_generation])
def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type, Callable[..., Any], BaseTool]],
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
"""Bind tools for use in model generation."""
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
return super().bind(tools=formatted_tools, **kwargs)
def with_structured_output(
self,
schema: Union[Dict, Type[BaseModel]],
*,
include_raw: bool = False,
method: Optional[Literal["json_mode", "function_calling"]] = "function_calling",
**kwargs: Any,
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
"""Model wrapper that returns outputs formatted to match the given schema."""
if kwargs:
raise ValueError(f"Received unsupported arguments {kwargs}")
is_pydantic_schema = _is_pydantic_class(schema)
if method == "function_calling":
if schema is None:
raise ValueError(
"schema must be specified when method is 'function_calling'. "
"Received None."
)
tool_name = convert_to_openai_tool(schema)["function"]["name"]
llm = self.bind_tools([schema], tool_choice=tool_name)
if is_pydantic_schema:
output_parser: OutputParserLike = PydanticToolsParser(
tools=[schema], # type: ignore[list-item]
first_tool_only=True, # type: ignore[list-item]
)
else:
output_parser = JsonOutputKeyToolsParser(
key_name=tool_name, first_tool_only=True
)
elif method == "json_mode":
llm = self.bind(response_format={"type": "json_object"})
output_parser = (
PydanticOutputParser(pydantic_object=schema) # type: ignore[type-var, arg-type]
if is_pydantic_schema
else JsonOutputParser()
)
else:
raise ValueError(
f"Unrecognized method argument. Expected one of 'function_calling' or "
f"'json_mode'. Received: '{method}'"
)
if include_raw:
parser_assign = RunnablePassthrough.assign(
parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
)
parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
parser_with_fallback = parser_assign.with_fallbacks(
[parser_none], exception_key="parsing_error"
)
return RunnableMap(raw=llm) | parser_with_fallback
else:
return llm | output_parser
@property
def _llm_type(self) -> str:
"""Return the type of the LLM (for Langchain compatibility)."""
return "cloudflare-workers-ai"

View File

@ -0,0 +1,78 @@
"""Test CloudflareWorkersAI Chat API wrapper."""
from typing import Any, Dict, List, Type
import pytest
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import (
AIMessage,
BaseMessage,
HumanMessage,
SystemMessage,
ToolMessage,
)
from langchain_standard_tests.unit_tests import ChatModelUnitTests
from langchain_community.chat_models.cloudflare_workersai import (
ChatCloudflareWorkersAI,
_convert_messages_to_cloudflare_messages,
)
class TestChatCloudflareWorkersAI(ChatModelUnitTests):
@property
def chat_model_class(self) -> Type[BaseChatModel]:
return ChatCloudflareWorkersAI
@property
def chat_model_params(self) -> dict:
return {
"account_id": "my_account_id",
"api_token": "my_api_token",
"model": "@hf/nousresearch/hermes-2-pro-mistral-7b",
}
@pytest.mark.parametrize(
("messages", "expected"),
[
# Test case with a single HumanMessage
(
[HumanMessage(content="Hello, AI!")],
[{"role": "user", "content": "Hello, AI!"}],
),
# Test case with SystemMessage, HumanMessage, and AIMessage without tool calls
(
[
SystemMessage(content="System initialized."),
HumanMessage(content="Hello, AI!"),
AIMessage(content="Response from AI"),
],
[
{"role": "system", "content": "System initialized."},
{"role": "user", "content": "Hello, AI!"},
{"role": "assistant", "content": "Response from AI"},
],
),
# Test case with ToolMessage and tool_call_id
(
[
ToolMessage(
content="Tool message content", tool_call_id="tool_call_123"
),
],
[
{
"role": "tool",
"content": "Tool message content",
"tool_call_id": "tool_call_123",
}
],
),
],
)
def test_convert_messages_to_cloudflare_format(
messages: List[BaseMessage], expected: List[Dict[str, Any]]
) -> None:
result = _convert_messages_to_cloudflare_messages(messages)
assert result == expected