mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-17 07:26:16 +00:00
community[minor]: Dappier chat model integration (#19370)
**Description:** This PR adds [Dappier](https://dappier.com/) for the chat model. It supports generate, async generate, and batch functionalities. We added unit and integration tests as well as a notebook with more details about our chat model. **Dependencies:** No extra dependencies are needed.
This commit is contained in:
committed by
GitHub
parent
64e1df3d3a
commit
743f888580
161
libs/community/langchain_community/chat_models/dappier.py
Normal file
161
libs/community/langchain_community/chat_models/dappier.py
Normal file
@@ -0,0 +1,161 @@
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from aiohttp import ClientSession
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models.chat_models import (
|
||||
BaseChatModel,
|
||||
)
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||
from langchain_core.pydantic_v1 import Extra, Field, SecretStr, root_validator
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
|
||||
|
||||
from langchain_community.utilities.requests import Requests
|
||||
|
||||
|
||||
def _format_dappier_messages(
|
||||
messages: List[BaseMessage],
|
||||
) -> List[Dict[str, Union[str, List[Union[str, Dict[Any, Any]]]]]]:
|
||||
formatted_messages = []
|
||||
|
||||
for message in messages:
|
||||
if message.type == "human":
|
||||
formatted_messages.append({"role": "user", "content": message.content})
|
||||
elif message.type == "system":
|
||||
formatted_messages.append({"role": "system", "content": message.content})
|
||||
|
||||
return formatted_messages
|
||||
|
||||
|
||||
class ChatDappierAI(BaseChatModel):
|
||||
"""`Dappier` chat large language models.
|
||||
|
||||
`Dappier` is a platform enabling access to diverse, real-time data models.
|
||||
Enhance your AI applications with Dappier's pre-trained, LLM-ready data models
|
||||
and ensure accurate, current responses with reduced inaccuracies.
|
||||
|
||||
To use one of our Dappier AI Data Models, you will need an API key.
|
||||
Please visit Dappier Platform (https://platform.dappier.com/) to log in
|
||||
and create an API key in your profile.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.chat_models import ChatDappierAI
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
# Initialize `ChatDappierAI` with the desired configuration
|
||||
chat = ChatDappierAI(
|
||||
dappier_endpoint="https://api.dappier.com/app/datamodel/dm_01hpsxyfm2fwdt2zet9cg6fdxt",
|
||||
dappier_api_key="<YOUR_KEY>")
|
||||
|
||||
# Create a list of messages to interact with the model
|
||||
messages = [HumanMessage(content="hello")]
|
||||
|
||||
# Invoke the model with the provided messages
|
||||
chat.invoke(messages)
|
||||
|
||||
|
||||
you can find more details here : https://docs.dappier.com/introduction"""
|
||||
|
||||
dappier_endpoint: str = "https://api.dappier.com/app/datamodelconversation"
|
||||
|
||||
dappier_model: str = "dm_01hpsxyfm2fwdt2zet9cg6fdxt"
|
||||
|
||||
dappier_api_key: Optional[SecretStr] = Field(None, description="Dappier API Token")
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key exists in environment."""
|
||||
values["dappier_api_key"] = convert_to_secret_str(
|
||||
get_from_dict_or_env(values, "dappier_api_key", "DAPPIER_API_KEY")
|
||||
)
|
||||
return values
|
||||
|
||||
@staticmethod
|
||||
def get_user_agent() -> str:
|
||||
from langchain_community import __version__
|
||||
|
||||
return f"langchain/{__version__}"
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of chat model."""
|
||||
return "dappier-realtimesearch-chat"
|
||||
|
||||
@property
|
||||
def _api_key(self) -> str:
|
||||
if self.dappier_api_key:
|
||||
return self.dappier_api_key.get_secret_value()
|
||||
return ""
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
url = f"{self.dappier_endpoint}"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self._api_key}",
|
||||
"User-Agent": self.get_user_agent(),
|
||||
}
|
||||
user_query = _format_dappier_messages(messages=messages)
|
||||
payload: Dict[str, Any] = {
|
||||
"model": self.dappier_model,
|
||||
"conversation": user_query,
|
||||
}
|
||||
|
||||
request = Requests(headers=headers)
|
||||
response = request.post(url=url, data=payload)
|
||||
response.raise_for_status()
|
||||
|
||||
data = response.json()
|
||||
|
||||
message_response = data["message"]
|
||||
|
||||
return ChatResult(
|
||||
generations=[ChatGeneration(message=AIMessage(content=message_response))]
|
||||
)
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
url = f"{self.dappier_endpoint}"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self._api_key}",
|
||||
"User-Agent": self.get_user_agent(),
|
||||
}
|
||||
user_query = _format_dappier_messages(messages=messages)
|
||||
payload: Dict[str, Any] = {
|
||||
"model": self.dappier_model,
|
||||
"conversation": user_query,
|
||||
}
|
||||
|
||||
async with ClientSession() as session:
|
||||
async with session.post(url, json=payload, headers=headers) as response:
|
||||
response.raise_for_status()
|
||||
data = await response.json()
|
||||
message_response = data["message"]
|
||||
|
||||
return ChatResult(
|
||||
generations=[
|
||||
ChatGeneration(message=AIMessage(content=message_response))
|
||||
]
|
||||
)
|
@@ -0,0 +1,58 @@
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
|
||||
from langchain_core.outputs import ChatGeneration, LLMResult
|
||||
|
||||
from langchain_community.chat_models.dappier import (
|
||||
ChatDappierAI,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
def test_dappier_chat() -> None:
|
||||
"""Test ChatDappierAI wrapper."""
|
||||
chat = ChatDappierAI(
|
||||
dappier_endpoint="https://api.dappier.com/app/datamodelconversation",
|
||||
dappier_model="dm_01hpsxyfm2fwdt2zet9cg6fdxt",
|
||||
)
|
||||
message = HumanMessage(content="Who are you ?")
|
||||
response = chat([message])
|
||||
assert isinstance(response, AIMessage)
|
||||
assert isinstance(response.content, str)
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
def test_dappier_generate() -> None:
|
||||
"""Test generate method of Dappier AI."""
|
||||
chat = ChatDappierAI(
|
||||
dappier_endpoint="https://api.dappier.com/app/datamodelconversation",
|
||||
dappier_model="dm_01hpsxyfm2fwdt2zet9cg6fdxt",
|
||||
)
|
||||
chat_messages: List[List[BaseMessage]] = [
|
||||
[HumanMessage(content="Who won the last super bowl?")],
|
||||
]
|
||||
messages_copy = [messages.copy() for messages in chat_messages]
|
||||
result: LLMResult = chat.generate(chat_messages)
|
||||
assert isinstance(result, LLMResult)
|
||||
for response in result.generations[0]:
|
||||
assert isinstance(response, ChatGeneration)
|
||||
assert isinstance(response.text, str)
|
||||
assert response.text == response.message.content
|
||||
assert chat_messages == messages_copy
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
async def test_dappier_agenerate() -> None:
|
||||
"""Test async generation."""
|
||||
chat = ChatDappierAI(
|
||||
dappier_endpoint="https://api.dappier.com/app/datamodelconversation",
|
||||
dappier_model="dm_01hpsxyfm2fwdt2zet9cg6fdxt",
|
||||
)
|
||||
message = HumanMessage(content="Who won the last super bowl?")
|
||||
result: LLMResult = await chat.agenerate([[message], [message]])
|
||||
assert isinstance(result, LLMResult)
|
||||
for response in result.generations[0]:
|
||||
assert isinstance(response, ChatGeneration)
|
||||
assert isinstance(response.text, str)
|
||||
assert response.text == response.message.content
|
34
libs/community/tests/unit_tests/chat_models/test_dappier.py
Normal file
34
libs/community/tests/unit_tests/chat_models/test_dappier.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""Test EdenAI Chat API wrapper."""
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
|
||||
|
||||
from langchain_community.chat_models.dappier import _format_dappier_messages
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("messages", "expected"),
|
||||
[
|
||||
(
|
||||
[
|
||||
SystemMessage(
|
||||
content="You are a chat model with real time search tools"
|
||||
),
|
||||
HumanMessage(content="Hello how are you today?"),
|
||||
],
|
||||
[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a chat model with real time search tools",
|
||||
},
|
||||
{"role": "user", "content": "Hello how are you today?"},
|
||||
],
|
||||
)
|
||||
],
|
||||
)
|
||||
def test_dappier_messages_formatting(
|
||||
messages: List[BaseMessage], expected: str
|
||||
) -> None:
|
||||
result = _format_dappier_messages(messages)
|
||||
assert result == expected
|
Reference in New Issue
Block a user