[Community] PremAI Tool Calling Functionality (#23931)

This PR is under WIP and adds the following functionalities:

- [X] Supports tool calling across the langchain ecosystem. (However
streaming is not supported)
- [X] Update documentation
This commit is contained in:
Anindyadeep
2024-07-24 19:23:58 +05:30
committed by GitHub
parent e271965d1e
commit 12c3454fd9
4 changed files with 602 additions and 65 deletions

View File

@@ -12,6 +12,7 @@ from typing import (
Iterator,
List,
Optional,
Sequence,
Tuple,
Type,
Union,
@@ -20,6 +21,7 @@ from typing import (
from langchain_core.callbacks import (
CallbackManagerForLLMRun,
)
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.language_models.llms import create_base_retry_decorator
from langchain_core.messages import (
@@ -33,6 +35,7 @@ from langchain_core.messages import (
HumanMessageChunk,
SystemMessage,
SystemMessageChunk,
ToolMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import (
@@ -41,7 +44,10 @@ from langchain_core.pydantic_v1 import (
Field,
SecretStr,
)
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langchain_core.utils import get_from_dict_or_env, pre_init
from langchain_core.utils.function_calling import convert_to_openai_tool
if TYPE_CHECKING:
from premai.api.chat_completions.v1_chat_completions_create import (
@@ -51,6 +57,19 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
TOOL_PROMPT_HEADER = """
Given the set of tools you used and the response, provide the final answer\n
"""
INTERMEDIATE_TOOL_RESULT_TEMPLATE = """
{json}
"""
SINGLE_TOOL_PROMPT_TEMPLATE = """
tool id: {tool_id}
tool_response: {tool_response}
"""
class ChatPremAPIError(Exception):
"""Error with the `PremAI` API."""
@@ -91,8 +110,22 @@ def _response_to_result(
raise ChatPremAPIError(f"ChatResponse must have a content: {content}")
if role == "assistant":
tool_calls = choice.message["tool_calls"]
if tool_calls is None:
tools = []
else:
tools = [
{
"id": tool_call["id"],
"name": tool_call["function"]["name"],
"args": tool_call["function"]["arguments"],
}
for tool_call in tool_calls
]
generations.append(
ChatGeneration(text=content, message=AIMessage(content=content))
ChatGeneration(
text=content, message=AIMessage(content=content, tool_calls=tools)
)
)
elif role == "user":
generations.append(
@@ -156,41 +189,65 @@ def _messages_to_prompt_dict(
system_prompt: Optional[str] = None
examples_and_messages: List[Dict[str, Any]] = []
if template_id is not None:
params: Dict[str, str] = {}
for input_msg in input_messages:
if isinstance(input_msg, SystemMessage):
system_prompt = str(input_msg.content)
for input_msg in input_messages:
if isinstance(input_msg, SystemMessage):
system_prompt = str(input_msg.content)
elif isinstance(input_msg, HumanMessage):
if template_id is None:
examples_and_messages.append(
{"role": "user", "content": str(input_msg.content)}
)
else:
params: Dict[str, str] = {}
assert (input_msg.id is not None) and (input_msg.id != ""), ValueError(
"When using prompt template there should be id associated ",
"with each HumanMessage",
)
params[str(input_msg.id)] = str(input_msg.content)
examples_and_messages.append(
{"role": "user", "template_id": template_id, "params": params}
)
for input_msg in input_messages:
if isinstance(input_msg, AIMessage):
examples_and_messages.append(
{"role": "assistant", "content": str(input_msg.content)}
{"role": "user", "template_id": template_id, "params": params}
)
else:
for input_msg in input_messages:
if isinstance(input_msg, SystemMessage):
system_prompt = str(input_msg.content)
elif isinstance(input_msg, HumanMessage):
examples_and_messages.append(
{"role": "user", "content": str(input_msg.content)}
)
elif isinstance(input_msg, AIMessage):
elif isinstance(input_msg, AIMessage):
if input_msg.tool_calls is None or len(input_msg.tool_calls) == 0:
examples_and_messages.append(
{"role": "assistant", "content": str(input_msg.content)}
)
else:
raise ChatPremAPIError("No such role explicitly exists")
ai_msg_to_json = {
"id": input_msg.id,
"content": input_msg.content,
"response_metadata": input_msg.response_metadata,
"tool_calls": input_msg.tool_calls,
}
examples_and_messages.append(
{
"role": "assistant",
"content": INTERMEDIATE_TOOL_RESULT_TEMPLATE.format(
json=ai_msg_to_json,
),
}
)
elif isinstance(input_msg, ToolMessage):
pass
else:
raise ChatPremAPIError("No such role explicitly exists")
# do a seperate search for tool calls
tool_prompt = ""
for input_msg in input_messages:
if isinstance(input_msg, ToolMessage):
tool_id = input_msg.tool_call_id
tool_result = input_msg.content
tool_prompt += SINGLE_TOOL_PROMPT_TEMPLATE.format(
tool_id=tool_id, tool_response=tool_result
)
if tool_prompt != "":
prompt = TOOL_PROMPT_HEADER
prompt += tool_prompt
examples_and_messages.append({"role": "user", "content": prompt})
return system_prompt, examples_and_messages
@@ -289,7 +346,6 @@ class ChatPremAI(BaseChatModel, BaseModel):
def _get_all_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
kwargs_to_ignore = [
"top_p",
"tools",
"frequency_penalty",
"presence_penalty",
"logit_bias",
@@ -392,6 +448,14 @@ class ChatPremAI(BaseChatModel, BaseModel):
except Exception as _:
continue
def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
return super().bind(tools=formatted_tools, **kwargs)
def create_prem_retry_decorator(
llm: ChatPremAI,

View File

@@ -3,12 +3,16 @@
from typing import cast
import pytest
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.pydantic_v1 import SecretStr
from pytest import CaptureFixture
from langchain_community.chat_models import ChatPremAI
from langchain_community.chat_models.premai import _messages_to_prompt_dict
from langchain_community.chat_models.premai import (
SINGLE_TOOL_PROMPT_TEMPLATE,
TOOL_PROMPT_HEADER,
_messages_to_prompt_dict,
)
@pytest.mark.requires("premai")
@@ -36,13 +40,20 @@ def test_messages_to_prompt_dict_with_valid_messages() -> None:
AIMessage(content="AI message #1"),
HumanMessage(content="User message #2"),
AIMessage(content="AI message #2"),
ToolMessage(content="Tool Message #1", tool_call_id="test_tool"),
AIMessage(content="AI message #3"),
]
)
expected_tool_message = SINGLE_TOOL_PROMPT_TEMPLATE.format(
tool_id="test_tool", tool_response="Tool Message #1"
)
expected = [
{"role": "user", "content": "User message #1"},
{"role": "assistant", "content": "AI message #1"},
{"role": "user", "content": "User message #2"},
{"role": "assistant", "content": "AI message #2"},
{"role": "assistant", "content": "AI message #3"},
{"role": "user", "content": TOOL_PROMPT_HEADER + expected_tool_message},
]
assert system_message == "System Prompt"