mirror of
https://github.com/hwchase17/langchain.git
synced 2026-05-17 13:00:49 +00:00
* Making `FakeToolCallingModel` generic on its `structured_response` doesn't help typing anywhere. * There are more than 120 references to `FakeToolCallingModel` in the code that produce `error: Need type annotation for "model" [var-annotated]` because mypy can't resolve the generic type (we don't see them at the moment because they are in files temporarily excluded from mypy checking). We would need to explicitly annotate them as `FakeToolCallingModel[Any]`. Co-authored-by: Mason Daugherty <mason@langchain.dev>
112 lines
3.7 KiB
Python
112 lines
3.7 KiB
Python
import json
|
|
from collections.abc import Callable, Sequence
|
|
from dataclasses import asdict, is_dataclass
|
|
from typing import (
|
|
Any,
|
|
Literal,
|
|
)
|
|
|
|
from langchain_core.callbacks import CallbackManagerForLLMRun
|
|
from langchain_core.language_models import BaseChatModel, LanguageModelInput
|
|
from langchain_core.messages import (
|
|
AIMessage,
|
|
BaseMessage,
|
|
ToolCall,
|
|
)
|
|
from langchain_core.outputs import ChatGeneration, ChatResult
|
|
from langchain_core.runnables import Runnable
|
|
from langchain_core.tools import BaseTool
|
|
from pydantic import BaseModel
|
|
from typing_extensions import override
|
|
|
|
|
|
class FakeToolCallingModel(BaseChatModel):
    """Fake chat model that replays scripted tool calls and structured output.

    Each call to `_generate` consumes one entry of `tool_calls` (selected by
    `index`) and then increments `index`. When a truthy `response_format`
    keyword argument is passed and the current turn has no tool calls, the
    model instead returns `structured_response` serialized as JSON.
    """

    # Scripted tool calls: one inner list per `_generate` call.
    tool_calls: list[list[ToolCall]] | list[list[dict[str, Any]]] | None = None
    # Object dumped to JSON as the message content on the structured-output path.
    structured_response: Any | None = None
    # Number of `_generate` calls made so far; also used as the message id.
    index: int = 0
    # Shape of the tool specs that `bind_tools` builds from `BaseTool` objects.
    tool_style: Literal["openai", "anthropic"] = "openai"

    def _generate(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Produce the next scripted message and advance `index`.

        Args:
            messages: Input messages; their `text` values are joined with `-`
                to form the content of the non-structured reply.
            stop: Not used by this fake model.
            run_manager: Not used by this fake model.
            **kwargs: Only `response_format` is read; a truthy value selects
                the structured-output behavior.

        Returns:
            A `ChatResult` holding a single `AIMessage` generation.

        Raises:
            TypeError: If the structured-output path is taken but
                `structured_response` is not a pydantic `BaseModel` instance,
                a dataclass instance, or a dict.
        """
        is_native = kwargs.get("response_format")

        if self.tool_calls:
            if is_native:
                # With a response format, the script is not cycled: once it is
                # exhausted there are no more tool calls, which lets the
                # structured response be emitted below.
                tool_calls = (
                    self.tool_calls[self.index] if self.index < len(self.tool_calls) else []
                )
            else:
                # Without a response format, the script wraps around.
                tool_calls = self.tool_calls[self.index % len(self.tool_calls)]
        else:
            tool_calls = []

        if is_native and not tool_calls:
            if isinstance(self.structured_response, BaseModel):
                content_obj = self.structured_response.model_dump()
            elif is_dataclass(self.structured_response) and not isinstance(
                self.structured_response, type
            ):
                # `is_dataclass` is also true for dataclass *classes*, but
                # `asdict` only accepts instances.
                content_obj = asdict(self.structured_response)
            elif isinstance(self.structured_response, dict):
                content_obj = self.structured_response
            else:
                # Previously this case fell through and failed with an
                # UnboundLocalError on `content_obj`; fail with a clear message.
                msg = (
                    "structured_response must be a pydantic BaseModel instance, "
                    "a dataclass instance, or a dict when response_format is set; "
                    f"got {type(self.structured_response).__name__}"
                )
                raise TypeError(msg)
            message = AIMessage(content=json.dumps(content_obj), id=str(self.index))
        else:
            messages_string = "-".join([m.text for m in messages])
            message = AIMessage(
                content=messages_string,
                id=str(self.index),
                # Shallow copy so the message does not share the scripted list object.
                tool_calls=tool_calls.copy(),
            )
        self.index += 1
        return ChatResult(generations=[ChatGeneration(message=message)])

    @property
    def _llm_type(self) -> str:
        """Return the identifier string for this model type."""
        return "fake-tool-call-model"

    @override
    def bind_tools(
        self,
        tools: Sequence[dict[str, Any] | type | Callable[..., Any] | BaseTool],
        *,
        tool_choice: str | None = None,
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, AIMessage]:
        """Bind simplified tool specs to the model.

        Dict tools are passed through unchanged; `BaseTool` instances are
        reduced to a minimal spec containing only the tool name, shaped
        according to `tool_style`.

        Args:
            tools: Tools to bind. Only dicts and `BaseTool` instances are
                supported, despite the wider annotation.
            tool_choice: Accepted for signature compatibility; not used here.
            **kwargs: Forwarded to `bind`.

        Returns:
            The result of `self.bind(tools=..., **kwargs)`.

        Raises:
            ValueError: If `tools` is empty.
            TypeError: If a tool is neither a dict nor a `BaseTool`.
        """
        if not tools:
            msg = "Must provide at least one tool"
            raise ValueError(msg)

        tool_dicts: list[dict[str, Any]] = []
        for tool in tools:
            if isinstance(tool, dict):
                tool_dicts.append(tool)
                continue
            if not isinstance(tool, BaseTool):
                msg = "Only BaseTool and dict is supported by FakeToolCallingModel.bind_tools"
                raise TypeError(msg)

            # NOTE: this is a simplified tool spec for testing purposes only
            if self.tool_style == "openai":
                tool_dicts.append(
                    {
                        "type": "function",
                        "function": {
                            "name": tool.name,
                        },
                    }
                )
            elif self.tool_style == "anthropic":
                tool_dicts.append(
                    {
                        "name": tool.name,
                    }
                )

        return self.bind(tools=tool_dicts, **kwargs)
|