Compare commits

...

2 Commits

Author SHA1 Message Date
Bagatur
1a640f07e7 fmt 2024-04-19 16:46:56 -07:00
Bagatur
46d3b07ad8 rfc: bind standard serialized tools 2024-04-19 16:34:48 -07:00
3 changed files with 55 additions and 7 deletions

View File

@@ -13,9 +13,11 @@ from typing import (
Dict,
Iterator,
List,
Literal,
Optional,
Sequence,
Type,
TypedDict,
Union,
cast,
)
@@ -50,14 +52,14 @@ from langchain_core.outputs import (
RunInfo,
)
from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
from langchain_core.runnables.config import ensure_config, run_in_executor
from langchain_core.tools import BaseTool
from langchain_core.tracers.log_stream import LogStreamCallbackHandler
from langchain_core.utils.function_calling import convert_to_openai_tool
if TYPE_CHECKING:
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables import Runnable, RunnableConfig
from langchain_core.tools import BaseTool
def generate_from_stream(stream: Iterator[ChatGenerationChunk]) -> ChatResult:
@@ -163,10 +165,15 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
metadata=config.get("metadata"),
run_name=config.get("run_name"),
run_id=config.pop("run_id", None),
**kwargs,
**self._prep_kwargs(**kwargs),
).generations[0][0],
).message
def _prep_kwargs(
self, __lc_serialized_tools__: Sequence[_SerializedTool] = (), **kwargs: Any
) -> Dict[str, Any]:
return kwargs
async def ainvoke(
self,
input: LanguageModelInput,
@@ -906,11 +913,52 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
tools: Sequence[_ToolLike],
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
return self._bind_tools(
tools,
__lc_serialized_tools__=self._as_lc_serialized_tools(tools),
**kwargs,
)
def _bind_tools(
self,
tools: Sequence[_ToolLike],
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
raise NotImplementedError()
def _as_lc_serialized_tools(
self, tools: Sequence[_ToolLike]
) -> List[_SerializedTool]:
ser_tools = []
for tool in tools:
ser = convert_to_openai_tool(tool)["function"]
ser_tools.append(
_SerializedTool(
title=ser.get("name", ""),
description=ser.get("description", ""),
type="object",
properties=ser["parameters"].get("properties", {}),
required=ser["parameters"].get("required", []),
)
)
return ser_tools
class _SerializedTool(TypedDict):
"""JSONSchema representing a tool"""
title: str
description: str
type: Literal["object"]
properties: Dict[str, Any]
required: List[str]
_ToolLike = Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]
class SimpleChatModel(BaseChatModel):
"""A simplified implementation for a chat model to inherit from."""

View File

@@ -4431,7 +4431,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
kwargs: optional kwargs to pass to the underlying runnable, when running
the underlying runnable (e.g., via `invoke`, `batch`,
`transform`, or `stream` or async variants)
config: config_factories:
config:
config_factories: optional list of config factories to apply to the
custom_input_type: Specify to override the input type of the underlying
runnable with a custom type.

View File

@@ -803,7 +803,7 @@ class ChatOpenAI(BaseChatModel):
**kwargs,
)
def bind_tools(
def _bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
*,