Add ChatModels wrapper for Cloudflare Workers AI (#27645)

Thank you for contributing to LangChain!

- [x] **PR title**: "community: chat models wrapper for Cloudflare
Workers AI"


- [x] **PR message**:
- **Description:** Add a chat models wrapper for Cloudflare Workers AI.
Enables LangGraph integration via ChatModel for tool usage and agentic
usage (see the usage sketch after this checklist).


- [x] **Add tests and docs**: If you're adding a new integration, please
include
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in the
`docs/docs/integrations` directory.


- [x] **Lint and test**: Run `make format`, `make lint` and `make test`
from the root of the package(s) you've modified. See contribution
guidelines for more: https://python.langchain.com/docs/contributing/

Additional guidelines:
- Make sure optional dependencies are imported within a function.
- Please do not add dependencies to pyproject.toml files (even optional
ones) unless they are required for unit tests.
- Most PRs should not touch more than one package.
- Changes should be backwards compatible.
- If you are adding something to community, do not re-import it in
langchain.

If no one reviews your PR within a few days, please @-mention one of
baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17.
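
For reviewers, a minimal usage sketch of the new wrapper. The account ID, API
token, and model name are placeholders taken from the unit tests below;
`invoke` and `bind_tools` are the standard `BaseChatModel` entry points this
integration is assumed to plug into, not APIs specific to this PR:

```python
# Minimal usage sketch with placeholder credentials; assumes the wrapper
# exposes the standard BaseChatModel interface, as the unit tests imply.
from langchain_core.messages import HumanMessage, SystemMessage

from langchain_community.chat_models.cloudflare_workersai import (
    ChatCloudflareWorkersAI,
)

llm = ChatCloudflareWorkersAI(
    account_id="my_account_id",  # Cloudflare account ID (placeholder)
    api_token="my_api_token",  # Workers AI API token (placeholder)
    model="@hf/nousresearch/hermes-2-pro-mistral-7b",
)

response = llm.invoke(
    [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="Hello, AI!"),
    ]
)
print(response.content)

# For the LangGraph / tool-usage path the PR mentions, binding tools via the
# standard `.bind_tools([...])` hook is the assumed entry point.
```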

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
Co-authored-by: Chester Curme <chester.curme@gmail.com>
Authored by Akshata on 2024-11-07 14:34:24 -06:00 (committed by GitHub)
commit 05fd6a16a9 · parent 8a5b9bf2ad
4 changed files with 588 additions and 1 deletion


@@ -0,0 +1,78 @@
"""Test CloudflareWorkersAI Chat API wrapper."""
from typing import Any, Dict, List, Type
import pytest
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import (
AIMessage,
BaseMessage,
HumanMessage,
SystemMessage,
ToolMessage,
)
from langchain_standard_tests.unit_tests import ChatModelUnitTests
from langchain_community.chat_models.cloudflare_workersai import (
ChatCloudflareWorkersAI,
_convert_messages_to_cloudflare_messages,
)
class TestChatCloudflareWorkersAI(ChatModelUnitTests):
@property
def chat_model_class(self) -> Type[BaseChatModel]:
return ChatCloudflareWorkersAI
@property
def chat_model_params(self) -> dict:
return {
"account_id": "my_account_id",
"api_token": "my_api_token",
"model": "@hf/nousresearch/hermes-2-pro-mistral-7b",
}
@pytest.mark.parametrize(
("messages", "expected"),
[
# Test case with a single HumanMessage
(
[HumanMessage(content="Hello, AI!")],
[{"role": "user", "content": "Hello, AI!"}],
),
# Test case with SystemMessage, HumanMessage, and AIMessage without tool calls
(
[
SystemMessage(content="System initialized."),
HumanMessage(content="Hello, AI!"),
AIMessage(content="Response from AI"),
],
[
{"role": "system", "content": "System initialized."},
{"role": "user", "content": "Hello, AI!"},
{"role": "assistant", "content": "Response from AI"},
],
),
# Test case with ToolMessage and tool_call_id
(
[
ToolMessage(
content="Tool message content", tool_call_id="tool_call_123"
),
],
[
{
"role": "tool",
"content": "Tool message content",
"tool_call_id": "tool_call_123",
}
],
),
],
)
def test_convert_messages_to_cloudflare_format(
messages: List[BaseMessage], expected: List[Dict[str, Any]]
) -> None:
result = _convert_messages_to_cloudflare_messages(messages)
assert result == expected
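
For context on what the parametrized cases above exercise, here is a hedged
sketch of a `_convert_messages_to_cloudflare_messages` helper that would
satisfy them. It is reconstructed purely from the asserted role mapping and is
not necessarily the implementation shipped in this PR:

```python
# Hypothetical reconstruction of the helper under test, inferred from the
# expected outputs asserted above; the PR's actual implementation may differ.
from typing import Any, Dict, List

from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    HumanMessage,
    SystemMessage,
    ToolMessage,
)


def _convert_messages_to_cloudflare_messages(
    messages: List[BaseMessage],
) -> List[Dict[str, Any]]:
    """Map LangChain messages to Cloudflare Workers AI chat roles."""
    converted: List[Dict[str, Any]] = []
    for message in messages:
        if isinstance(message, HumanMessage):
            entry: Dict[str, Any] = {"role": "user", "content": message.content}
        elif isinstance(message, AIMessage):
            entry = {"role": "assistant", "content": message.content}
        elif isinstance(message, SystemMessage):
            entry = {"role": "system", "content": message.content}
        elif isinstance(message, ToolMessage):
            entry = {
                "role": "tool",
                "content": message.content,
                "tool_call_id": message.tool_call_id,
            }
        else:
            raise ValueError(f"Unsupported message type: {type(message)}")
        converted.append(entry)
    return converted
```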