mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-09 06:53:59 +00:00
[Community] PremAI Tool Calling Functionality (#23931)
This PR is under WIP and adds the following functionalities: - [X] Supports tool calling across the langchain ecosystem. (However streaming is not supported) - [X] Update documentation
This commit is contained in:
@@ -12,6 +12,7 @@ from typing import (
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
@@ -20,6 +21,7 @@ from typing import (
|
||||
from langchain_core.callbacks import (
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models import LanguageModelInput
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.language_models.llms import create_base_retry_decorator
|
||||
from langchain_core.messages import (
|
||||
@@ -33,6 +35,7 @@ from langchain_core.messages import (
|
||||
HumanMessageChunk,
|
||||
SystemMessage,
|
||||
SystemMessageChunk,
|
||||
ToolMessage,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.pydantic_v1 import (
|
||||
@@ -41,7 +44,10 @@ from langchain_core.pydantic_v1 import (
|
||||
Field,
|
||||
SecretStr,
|
||||
)
|
||||
from langchain_core.runnables import Runnable
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils import get_from_dict_or_env, pre_init
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from premai.api.chat_completions.v1_chat_completions_create import (
|
||||
@@ -51,6 +57,19 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TOOL_PROMPT_HEADER = """
|
||||
Given the set of tools you used and the response, provide the final answer\n
|
||||
"""
|
||||
|
||||
INTERMEDIATE_TOOL_RESULT_TEMPLATE = """
|
||||
{json}
|
||||
"""
|
||||
|
||||
SINGLE_TOOL_PROMPT_TEMPLATE = """
|
||||
tool id: {tool_id}
|
||||
tool_response: {tool_response}
|
||||
"""
|
||||
|
||||
|
||||
class ChatPremAPIError(Exception):
|
||||
"""Error with the `PremAI` API."""
|
||||
@@ -91,8 +110,22 @@ def _response_to_result(
|
||||
raise ChatPremAPIError(f"ChatResponse must have a content: {content}")
|
||||
|
||||
if role == "assistant":
|
||||
tool_calls = choice.message["tool_calls"]
|
||||
if tool_calls is None:
|
||||
tools = []
|
||||
else:
|
||||
tools = [
|
||||
{
|
||||
"id": tool_call["id"],
|
||||
"name": tool_call["function"]["name"],
|
||||
"args": tool_call["function"]["arguments"],
|
||||
}
|
||||
for tool_call in tool_calls
|
||||
]
|
||||
generations.append(
|
||||
ChatGeneration(text=content, message=AIMessage(content=content))
|
||||
ChatGeneration(
|
||||
text=content, message=AIMessage(content=content, tool_calls=tools)
|
||||
)
|
||||
)
|
||||
elif role == "user":
|
||||
generations.append(
|
||||
@@ -156,41 +189,65 @@ def _messages_to_prompt_dict(
|
||||
system_prompt: Optional[str] = None
|
||||
examples_and_messages: List[Dict[str, Any]] = []
|
||||
|
||||
if template_id is not None:
|
||||
params: Dict[str, str] = {}
|
||||
for input_msg in input_messages:
|
||||
if isinstance(input_msg, SystemMessage):
|
||||
system_prompt = str(input_msg.content)
|
||||
for input_msg in input_messages:
|
||||
if isinstance(input_msg, SystemMessage):
|
||||
system_prompt = str(input_msg.content)
|
||||
|
||||
elif isinstance(input_msg, HumanMessage):
|
||||
if template_id is None:
|
||||
examples_and_messages.append(
|
||||
{"role": "user", "content": str(input_msg.content)}
|
||||
)
|
||||
else:
|
||||
params: Dict[str, str] = {}
|
||||
assert (input_msg.id is not None) and (input_msg.id != ""), ValueError(
|
||||
"When using prompt template there should be id associated ",
|
||||
"with each HumanMessage",
|
||||
)
|
||||
params[str(input_msg.id)] = str(input_msg.content)
|
||||
|
||||
examples_and_messages.append(
|
||||
{"role": "user", "template_id": template_id, "params": params}
|
||||
)
|
||||
|
||||
for input_msg in input_messages:
|
||||
if isinstance(input_msg, AIMessage):
|
||||
examples_and_messages.append(
|
||||
{"role": "assistant", "content": str(input_msg.content)}
|
||||
{"role": "user", "template_id": template_id, "params": params}
|
||||
)
|
||||
else:
|
||||
for input_msg in input_messages:
|
||||
if isinstance(input_msg, SystemMessage):
|
||||
system_prompt = str(input_msg.content)
|
||||
elif isinstance(input_msg, HumanMessage):
|
||||
examples_and_messages.append(
|
||||
{"role": "user", "content": str(input_msg.content)}
|
||||
)
|
||||
elif isinstance(input_msg, AIMessage):
|
||||
elif isinstance(input_msg, AIMessage):
|
||||
if input_msg.tool_calls is None or len(input_msg.tool_calls) == 0:
|
||||
examples_and_messages.append(
|
||||
{"role": "assistant", "content": str(input_msg.content)}
|
||||
)
|
||||
else:
|
||||
raise ChatPremAPIError("No such role explicitly exists")
|
||||
ai_msg_to_json = {
|
||||
"id": input_msg.id,
|
||||
"content": input_msg.content,
|
||||
"response_metadata": input_msg.response_metadata,
|
||||
"tool_calls": input_msg.tool_calls,
|
||||
}
|
||||
examples_and_messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": INTERMEDIATE_TOOL_RESULT_TEMPLATE.format(
|
||||
json=ai_msg_to_json,
|
||||
),
|
||||
}
|
||||
)
|
||||
elif isinstance(input_msg, ToolMessage):
|
||||
pass
|
||||
|
||||
else:
|
||||
raise ChatPremAPIError("No such role explicitly exists")
|
||||
|
||||
# do a seperate search for tool calls
|
||||
tool_prompt = ""
|
||||
for input_msg in input_messages:
|
||||
if isinstance(input_msg, ToolMessage):
|
||||
tool_id = input_msg.tool_call_id
|
||||
tool_result = input_msg.content
|
||||
tool_prompt += SINGLE_TOOL_PROMPT_TEMPLATE.format(
|
||||
tool_id=tool_id, tool_response=tool_result
|
||||
)
|
||||
if tool_prompt != "":
|
||||
prompt = TOOL_PROMPT_HEADER
|
||||
prompt += tool_prompt
|
||||
examples_and_messages.append({"role": "user", "content": prompt})
|
||||
|
||||
return system_prompt, examples_and_messages
|
||||
|
||||
|
||||
@@ -289,7 +346,6 @@ class ChatPremAI(BaseChatModel, BaseModel):
|
||||
def _get_all_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
|
||||
kwargs_to_ignore = [
|
||||
"top_p",
|
||||
"tools",
|
||||
"frequency_penalty",
|
||||
"presence_penalty",
|
||||
"logit_bias",
|
||||
@@ -392,6 +448,14 @@ class ChatPremAI(BaseChatModel, BaseModel):
|
||||
except Exception as _:
|
||||
continue
|
||||
|
||||
def bind_tools(
|
||||
self,
|
||||
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, BaseMessage]:
|
||||
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
|
||||
return super().bind(tools=formatted_tools, **kwargs)
|
||||
|
||||
|
||||
def create_prem_retry_decorator(
|
||||
llm: ChatPremAI,
|
||||
|
@@ -3,12 +3,16 @@
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
||||
from langchain_core.pydantic_v1 import SecretStr
|
||||
from pytest import CaptureFixture
|
||||
|
||||
from langchain_community.chat_models import ChatPremAI
|
||||
from langchain_community.chat_models.premai import _messages_to_prompt_dict
|
||||
from langchain_community.chat_models.premai import (
|
||||
SINGLE_TOOL_PROMPT_TEMPLATE,
|
||||
TOOL_PROMPT_HEADER,
|
||||
_messages_to_prompt_dict,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("premai")
|
||||
@@ -36,13 +40,20 @@ def test_messages_to_prompt_dict_with_valid_messages() -> None:
|
||||
AIMessage(content="AI message #1"),
|
||||
HumanMessage(content="User message #2"),
|
||||
AIMessage(content="AI message #2"),
|
||||
ToolMessage(content="Tool Message #1", tool_call_id="test_tool"),
|
||||
AIMessage(content="AI message #3"),
|
||||
]
|
||||
)
|
||||
expected_tool_message = SINGLE_TOOL_PROMPT_TEMPLATE.format(
|
||||
tool_id="test_tool", tool_response="Tool Message #1"
|
||||
)
|
||||
expected = [
|
||||
{"role": "user", "content": "User message #1"},
|
||||
{"role": "assistant", "content": "AI message #1"},
|
||||
{"role": "user", "content": "User message #2"},
|
||||
{"role": "assistant", "content": "AI message #2"},
|
||||
{"role": "assistant", "content": "AI message #3"},
|
||||
{"role": "user", "content": TOOL_PROMPT_HEADER + expected_tool_message},
|
||||
]
|
||||
|
||||
assert system_message == "System Prompt"
|
||||
|
Reference in New Issue
Block a user