langchain/libs/community/langchain_community/chat_models/cloudflare_workersai.py
Akshata 05fd6a16a9
Add ChatModels wrapper for Cloudflare Workers AI (#27645)
Thank you for contributing to LangChain!

- [x] **PR title**: "community: chat models wrapper for Cloudflare
Workers AI"


- [x] **PR message**:
- **Description:** Add a chat model wrapper for Cloudflare Workers AI.
Enables LangGraph integration via a ChatModel for tool use and agentic
workflows.


- [x] **Add tests and docs**: If you're adding a new integration, please
include
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in
`docs/docs/integrations` directory.


- [x] **Lint and test**: Run `make format`, `make lint` and `make test`
from the root of the package(s) you've modified. See contribution
guidelines for more: https://python.langchain.com/docs/contributing/

Additional guidelines:
- Make sure optional dependencies are imported within a function.
- Please do not add dependencies to pyproject.toml files (even optional
ones) unless they are required for unit tests.
- Most PRs should not touch more than one package.
- Changes should be backwards compatible.
- If you are adding something to community, do not re-import it in
langchain.

If no one reviews your PR within a few days, please @-mention one of
baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17.

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
Co-authored-by: Chester Curme <chester.curme@gmail.com>
2024-11-07 15:34:24 -05:00

246 lines
8.7 KiB
Python

import logging
from operator import itemgetter
from typing import (
Any,
Callable,
Dict,
List,
Literal,
Optional,
Sequence,
Type,
Union,
cast,
)
from uuid import uuid4
import requests
from langchain.schema import AIMessage, ChatGeneration, ChatResult, HumanMessage
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
AIMessageChunk,
BaseMessage,
SystemMessage,
ToolCall,
ToolMessage,
)
from langchain_core.messages.tool import tool_call
from langchain_core.output_parsers import (
JsonOutputParser,
PydanticOutputParser,
)
from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.output_parsers.openai_tools import (
JsonOutputKeyToolsParser,
PydanticToolsParser,
)
from langchain_core.runnables import Runnable, RunnablePassthrough
from langchain_core.runnables.base import RunnableMap
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_core.utils.pydantic import is_basemodel_subclass
from pydantic import BaseModel, Field
# Module-level logger for this integration. Handler, level and format
# configuration is deliberately left to the importing application: a library
# module must not call ``logging.basicConfig``, because that would configure
# the *root* logger as a side effect of merely importing this file.
_logger = logging.getLogger(__name__)
def _is_pydantic_class(obj: Any) -> bool:
    """Tell whether ``obj`` is a class that subclasses pydantic ``BaseModel``.

    Non-class objects (instances, ``None``, strings, ...) are rejected up
    front, so ``is_basemodel_subclass`` is only consulted for actual types.
    """
    if not isinstance(obj, type):
        return False
    return is_basemodel_subclass(obj)
def _message_content_to_text(content: Union[str, List[Union[str, Dict]]]) -> str:
    """Flatten LangChain message content into a plain string.

    String content is returned unchanged. For list ("content block") content,
    plain string parts and the ``text`` field of ``{"type": "text", ...}``
    blocks are concatenated; any other block type (e.g. images) is skipped
    because the prompt sent to Workers AI is text-only.
    """
    if isinstance(content, str):
        return content
    parts: List[str] = []
    for block in content:
        if isinstance(block, str):
            parts.append(block)
        elif isinstance(block, dict) and block.get("type") == "text":
            parts.append(str(block.get("text", "")))
    return "".join(parts)


def _convert_messages_to_cloudflare_messages(
    messages: List[BaseMessage],
) -> List[Dict[str, Any]]:
    """Convert LangChain messages to Cloudflare Workers AI format.

    Each message becomes a dict with ``role`` and ``content`` keys:

    * ``HumanMessage`` -> role ``"user"``.
    * ``AIMessage`` -> role ``"assistant"``; tool calls, if any, are added
      under ``"tool_calls"`` as ``{"name": ..., "arguments": ...}`` dicts.
    * ``SystemMessage`` -> role ``"system"``.
    * ``ToolMessage`` -> role ``"tool"``, plus its ``tool_call_id``.

    Any other message type is emitted with an empty ``role``.

    Args:
        messages: The LangChain messages to convert.

    Returns:
        One Cloudflare-style message dict per input message, in order.
    """
    cloudflare_messages: List[Dict[str, Any]] = []
    for message in messages:
        msg: Dict[str, Any] = {
            "role": "",
            # Previously list content was replaced by "" and its text lost.
            "content": _message_content_to_text(message.content),
        }
        if isinstance(message, HumanMessage):
            msg["role"] = "user"
        elif isinstance(message, AIMessage):
            msg["role"] = "assistant"
            if message.tool_calls:
                # ``call`` (not ``tool_call``) avoids shadowing the imported
                # ``tool_call`` factory function.
                msg["tool_calls"] = [
                    {"name": call["name"], "arguments": call["args"]}
                    for call in message.tool_calls
                ]
        elif isinstance(message, SystemMessage):
            msg["role"] = "system"
        elif isinstance(message, ToolMessage):
            msg["role"] = "tool"
            msg["tool_call_id"] = message.tool_call_id
        cloudflare_messages.append(msg)
    return cloudflare_messages
def _get_tool_calls_from_response(response: requests.Response) -> List[ToolCall]:
    """Extract tool calls from a Cloudflare Workers AI response.

    Reads ``result.tool_calls`` from the JSON body. Each entry must provide
    ``name`` and ``arguments`` keys. Workers AI does not return tool-call
    ids, so a fresh UUID is generated for every call.

    Args:
        response: HTTP response from the Workers AI ``run`` endpoint.

    Returns:
        A list of LangChain ``ToolCall`` dicts. The list is empty when the
        body has no ``result`` or its ``tool_calls`` is missing, ``null`` or
        empty.
    """
    # Parse the body once instead of re-decoding it for every lookup.
    result = response.json().get("result") or {}
    return [
        tool_call(id=str(uuid4()), name=tc["name"], args=tc["arguments"])
        for tc in result.get("tool_calls") or []
    ]
class ChatCloudflareWorkersAI(BaseChatModel):
    """Chat model backed by the Cloudflare Workers AI ``run`` endpoint.

    Requests go directly to the Workers AI REST API or, when ``ai_gateway``
    is set, through a Cloudflare AI Gateway.

    Example:
        .. code-block:: python

            llm = ChatCloudflareWorkersAI(
                account_id="<account-id>",
                api_token="<api-token>",
                model="<model-id>",
            )
    """

    # Cloudflare account ID; embedded in the request URL.
    account_id: str = Field(...)
    # API token, sent as a Bearer token in the Authorization header.
    api_token: str = Field(...)
    # Workers AI model identifier; appended to the ``run`` URL.
    model: str = Field(...)
    # Optional AI Gateway name; when non-empty, requests use ``gateway_url``.
    ai_gateway: str = ""
    # Full endpoint URL. Built in ``__init__`` unless explicitly provided.
    url: str = ""
    base_url: str = "https://api.cloudflare.com/client/v4/accounts"
    gateway_url: str = "https://gateway.ai.cloudflare.com/v1"
    # Seconds to wait for the HTTP request; ``None`` waits indefinitely.
    timeout: Optional[float] = None

    def __init__(self, **kwargs: Any) -> None:
        """Initialize with credentials and build the endpoint URL.

        An explicitly supplied ``url`` is kept as-is. Otherwise the URL is
        derived from ``account_id`` and ``model``, routed through
        ``gateway_url`` when ``ai_gateway`` is set and ``base_url`` otherwise.
        """
        super().__init__(**kwargs)
        if self.url:
            # Respect a caller-provided endpoint instead of overwriting it.
            return
        if self.ai_gateway:
            self.url = (
                f"{self.gateway_url}/{self.account_id}/"
                f"{self.ai_gateway}/workers-ai/run/{self.model}"
            )
        else:
            self.url = f"{self.base_url}/{self.account_id}/ai/run/{self.model}"

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Send ``messages`` to Workers AI and wrap the reply in a ``ChatResult``.

        The conversation is flattened into a single ``prompt`` string with one
        ``role: ..., content: ...`` line per message. Extra ``kwargs`` (such as
        ``tools``, ``tool_choice`` or ``response_format`` bound via
        :meth:`bind_tools` / ``bind``) are forwarded in the JSON request body.
        ``stop`` is accepted for interface compatibility but is not sent.

        Raises:
            requests.HTTPError: If the API responds with an error status.
        """
        import json

        formatted_messages = _convert_messages_to_cloudflare_messages(messages)
        headers = {"Authorization": f"Bearer {self.api_token}"}
        prompt = "\n".join(
            f"role: {msg['role']}, content: {msg['content']}"
            + (f", tools: {msg['tool_calls']}" if "tool_calls" in msg else "")
            + (
                f", tool_call_id: {msg['tool_call_id']}"
                if "tool_call_id" in msg
                else ""
            )
            for msg in formatted_messages
        )
        data = {
            "prompt": prompt,
            "tools": kwargs["tools"] if "tools" in kwargs else None,
            **{key: value for key, value in kwargs.items() if key not in ["tools"]},
        }
        # The API expects ``tools`` as a list; wrap a single tool definition.
        if data["tools"] is not None and not isinstance(data["tools"], list):
            data["tools"] = [data["tools"]]
        _logger.info("Sending prompt to Cloudflare Workers AI: %s", data)
        response = requests.post(
            self.url, headers=headers, json=data, timeout=self.timeout
        )
        # Fail with a clear HTTP error rather than a confusing KeyError /
        # TypeError while digging into an error body.
        response.raise_for_status()
        result = response.json().get("result") or {}
        tool_calls = _get_tool_calls_from_response(response)
        # ``result.response`` carries the generated output. It may be absent
        # when the model only emitted tool calls. Non-string payloads are
        # JSON-serialized so downstream JSON parsers (json_mode) can read them.
        # Previously the content was ``str()`` of the whole response envelope.
        raw_content = result.get("response")
        if raw_content is None:
            content = ""
        elif isinstance(raw_content, str):
            content = raw_content
        else:
            content = json.dumps(raw_content)
        ai_message = AIMessage(content=content, tool_calls=tool_calls)
        return ChatResult(generations=[ChatGeneration(message=ai_message)])

    def bind_tools(
        self,
        tools: Sequence[Union[Dict[str, Any], Type, Callable[..., Any], BaseTool]],
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, BaseMessage]:
        """Bind tools for use in model generation.

        Each tool is converted to the OpenAI tool schema. The result is bound
        as the ``tools`` kwarg, which ``_generate`` forwards to the API.
        Extra ``kwargs`` (e.g. ``tool_choice``) are bound alongside it.
        """
        formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
        return super().bind(tools=formatted_tools, **kwargs)

    def with_structured_output(
        self,
        schema: Union[Dict, Type[BaseModel]],
        *,
        include_raw: bool = False,
        method: Optional[Literal["json_mode", "function_calling"]] = "function_calling",
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
        """Model wrapper that returns outputs formatted to match the given schema.

        Args:
            schema: A pydantic model class or a dict schema.
            include_raw: If True, return a dict with ``raw`` (the model
                message), ``parsed`` and ``parsing_error`` keys instead of
                only the parsed output.
            method: ``"function_calling"`` binds ``schema`` as a forced tool
                and parses the first tool call. ``"json_mode"`` requests a
                JSON object response and parses the message content.

        Returns:
            A runnable producing the parsed output (or the ``include_raw``
            dict).

        Raises:
            ValueError: On unsupported extra kwargs, an unknown ``method``,
                or a ``None`` schema with ``"function_calling"``.
        """
        if kwargs:
            raise ValueError(f"Received unsupported arguments {kwargs}")
        is_pydantic_schema = _is_pydantic_class(schema)
        if method == "function_calling":
            if schema is None:
                raise ValueError(
                    "schema must be specified when method is 'function_calling'. "
                    "Received None."
                )
            tool_name = convert_to_openai_tool(schema)["function"]["name"]
            llm = self.bind_tools([schema], tool_choice=tool_name)
            if is_pydantic_schema:
                output_parser: OutputParserLike = PydanticToolsParser(
                    tools=[schema],  # type: ignore[list-item]
                    first_tool_only=True,  # type: ignore[list-item]
                )
            else:
                output_parser = JsonOutputKeyToolsParser(
                    key_name=tool_name, first_tool_only=True
                )
        elif method == "json_mode":
            llm = self.bind(response_format={"type": "json_object"})
            output_parser = (
                PydanticOutputParser(pydantic_object=schema)  # type: ignore[type-var, arg-type]
                if is_pydantic_schema
                else JsonOutputParser()
            )
        else:
            raise ValueError(
                f"Unrecognized method argument. Expected one of 'function_calling' or "
                f"'json_mode'. Received: '{method}'"
            )
        if include_raw:
            # Parse ``raw``; on failure fall back to ``parsed=None`` and
            # record the exception under ``parsing_error``.
            parser_assign = RunnablePassthrough.assign(
                parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
            )
            parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
            parser_with_fallback = parser_assign.with_fallbacks(
                [parser_none], exception_key="parsing_error"
            )
            return RunnableMap(raw=llm) | parser_with_fallback
        else:
            return llm | output_parser

    @property
    def _llm_type(self) -> str:
        """Return the type of the LLM (for Langchain compatibility)."""
        return "cloudflare-workers-ai"