mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
anthropic[minor]: tool use (#20016)
This commit is contained in:
@@ -1,38 +1,13 @@
|
||||
import json
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from langchain_core._api.beta_decorator import beta
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models import LanguageModelInput
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
BaseMessageChunk,
|
||||
SystemMessage,
|
||||
)
|
||||
from langchain_core.output_parsers.openai_tools import (
|
||||
JsonOutputKeyToolsParser,
|
||||
PydanticToolsParser,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
|
||||
from langchain_core.runnables import Runnable
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils.function_calling import convert_to_openai_function
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
|
||||
from langchain_anthropic.chat_models import ChatAnthropic
|
||||
|
||||
@@ -168,143 +143,16 @@ def _xml_to_tool_calls(elem: Any, tools: List[Dict]) -> List[Dict[str, Any]]:
|
||||
return [_xml_to_function_call(invoke, tools) for invoke in invokes]
|
||||
|
||||
|
||||
@beta()
|
||||
@deprecated(
|
||||
"0.1.5",
|
||||
removal="0.2.0",
|
||||
alternative="ChatAnthropic",
|
||||
message=(
|
||||
"Tool-calling is now officially supported by the Anthropic API so this "
|
||||
"workaround is no longer needed."
|
||||
),
|
||||
)
|
||||
class ChatAnthropicTools(ChatAnthropic):
|
||||
"""Chat model for interacting with Anthropic functions."""
|
||||
|
||||
_xmllib: Any = Field(default=None)
|
||||
|
||||
@root_validator()
|
||||
def check_xml_lib(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
try:
|
||||
# do this as an optional dep for temporary nature of this feature
|
||||
import defusedxml.ElementTree as DET # type: ignore
|
||||
|
||||
values["_xmllib"] = DET
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import defusedxml python package. "
|
||||
"Please install it using `pip install defusedxml`"
|
||||
)
|
||||
return values
|
||||
|
||||
def bind_tools(
|
||||
self,
|
||||
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], BaseTool]],
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, BaseMessage]:
|
||||
"""Bind tools to the chat model."""
|
||||
formatted_tools = [convert_to_openai_function(tool) for tool in tools]
|
||||
return super().bind(tools=formatted_tools, **kwargs)
|
||||
|
||||
def with_structured_output(
|
||||
self, schema: Union[Dict, Type[BaseModel]], **kwargs: Any
|
||||
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
|
||||
if kwargs:
|
||||
raise ValueError("kwargs are not supported for with_structured_output")
|
||||
llm = self.bind_tools([schema])
|
||||
if isinstance(schema, type) and issubclass(schema, BaseModel):
|
||||
# schema is pydantic
|
||||
return llm | PydanticToolsParser(tools=[schema], first_tool_only=True)
|
||||
else:
|
||||
# schema is dict
|
||||
key_name = convert_to_openai_function(schema)["name"]
|
||||
return llm | JsonOutputKeyToolsParser(
|
||||
key_name=key_name, first_tool_only=True
|
||||
)
|
||||
|
||||
def _format_params(
|
||||
self,
|
||||
*,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Dict:
|
||||
tools: List[Dict] = kwargs.get("tools", None)
|
||||
# experimental tools are sent in as part of system prompt, so if
|
||||
# both are set, turn system prompt into tools + system prompt (tools first)
|
||||
if tools:
|
||||
tool_system = get_system_message(tools)
|
||||
|
||||
if messages[0].type == "system":
|
||||
sys_content = messages[0].content
|
||||
new_sys_content = f"{tool_system}\n\n{sys_content}"
|
||||
messages = [SystemMessage(content=new_sys_content), *messages[1:]]
|
||||
else:
|
||||
messages = [SystemMessage(content=tool_system), *messages]
|
||||
|
||||
return super()._format_params(messages=messages, stop=stop, **kwargs)
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
# streaming not supported for functions
|
||||
result = self._generate(
|
||||
messages=messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
to_yield = result.generations[0]
|
||||
chunk = ChatGenerationChunk(
|
||||
message=cast(BaseMessageChunk, to_yield.message),
|
||||
generation_info=to_yield.generation_info,
|
||||
)
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(
|
||||
cast(str, to_yield.message.content), chunk=chunk
|
||||
)
|
||||
yield chunk
|
||||
|
||||
async def _astream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
# streaming not supported for functions
|
||||
result = await self._agenerate(
|
||||
messages=messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
to_yield = result.generations[0]
|
||||
chunk = ChatGenerationChunk(
|
||||
message=cast(BaseMessageChunk, to_yield.message),
|
||||
generation_info=to_yield.generation_info,
|
||||
)
|
||||
if run_manager:
|
||||
await run_manager.on_llm_new_token(
|
||||
cast(str, to_yield.message.content), chunk=chunk
|
||||
)
|
||||
yield chunk
|
||||
|
||||
def _format_output(self, data: Any, **kwargs: Any) -> ChatResult:
|
||||
"""Format the output of the model, parsing xml as a tool call."""
|
||||
text = data.content[0].text
|
||||
tools = kwargs.get("tools", None)
|
||||
|
||||
additional_kwargs: Dict[str, Any] = {}
|
||||
|
||||
if tools:
|
||||
# parse out the xml from the text
|
||||
try:
|
||||
# get everything between <function_calls> and </function_calls>
|
||||
start = text.find("<function_calls>")
|
||||
end = text.find("</function_calls>") + len("</function_calls>")
|
||||
xml_text = text[start:end]
|
||||
|
||||
xml = self._xmllib.fromstring(xml_text)
|
||||
additional_kwargs["tool_calls"] = _xml_to_tool_calls(xml, tools)
|
||||
text = ""
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return ChatResult(
|
||||
generations=[
|
||||
ChatGeneration(
|
||||
message=AIMessage(content=text, additional_kwargs=additional_kwargs)
|
||||
)
|
||||
],
|
||||
llm_output=data,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user