Implement a router for openai functions (#8589)

commit 808248049d (parent a6e6e9bb86)
@@ -62,7 +62,14 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
         config: Optional[RunnableConfig] = None,
         **kwargs: Any,
     ) -> Dict[str, Any]:
-        return self(input, **(config or {}), **kwargs)
+        config = config or {}
+        return self(
+            input,
+            callbacks=config.get("callbacks"),
+            tags=config.get("tags"),
+            metadata=config.get("metadata"),
+            **kwargs,
+        )
 
     async def ainvoke(
         self,
@@ -76,7 +83,14 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
                 None, partial(self.invoke, input, config, **kwargs)
             )
 
-        return await self.acall(input, **(config or {}), **kwargs)
+        config = config or {}
+        return await self.acall(
+            input,
+            callbacks=config.get("callbacks"),
+            tags=config.get("tags"),
+            metadata=config.get("metadata"),
+            **kwargs,
+        )
 
     memory: Optional[BaseMemory] = None
     """Optional memory object. Defaults to None.
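
Note: one pattern recurs through this commit. Instead of splatting the whole RunnableConfig into the legacy entry points with **(config or {}), each invoke/ainvoke now forwards only the three keys those call signatures accept. A minimal plain-Python sketch of the extraction (the helper name is hypothetical, not part of the diff):

    from typing import Any, Dict, Optional

    # Hypothetical helper mirroring the pattern applied in each hunk:
    # keep only the RunnableConfig keys the legacy call signatures accept.
    def _config_kwargs(config: Optional[Dict[str, Any]]) -> Dict[str, Any]:
        config = config or {}
        return {
            "callbacks": config.get("callbacks"),
            "tags": config.get("tags"),
            "metadata": config.get("metadata"),
        }

    print(_config_kwargs(None))
    print(_config_kwargs({"tags": ["demo"], "unknown_key": 5}))  # extra keys dropped
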
@@ -103,12 +103,18 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
         stop: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> BaseMessageChunk:
+        config = config or {}
         return cast(
             BaseMessageChunk,
             cast(
                 ChatGeneration,
                 self.generate_prompt(
-                    [self._convert_input(input)], stop=stop, **(config or {}), **kwargs
+                    [self._convert_input(input)],
+                    stop=stop,
+                    callbacks=config.get("callbacks"),
+                    tags=config.get("tags"),
+                    metadata=config.get("metadata"),
+                    **kwargs,
                 ).generations[0][0],
             ).message,
         )
@@ -127,8 +133,14 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
                 None, partial(self.invoke, input, config, stop=stop, **kwargs)
             )
 
+        config = config or {}
         llm_result = await self.agenerate_prompt(
-            [self._convert_input(input)], stop=stop, **(config or {}), **kwargs
+            [self._convert_input(input)],
+            stop=stop,
+            callbacks=config.get("callbacks"),
+            tags=config.get("tags"),
+            metadata=config.get("metadata"),
+            **kwargs,
         )
         return cast(
             BaseMessageChunk, cast(ChatGeneration, llm_result.generations[0][0]).message
@@ -219,9 +219,15 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         stop: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> str:
+        config = config or {}
         return (
             self.generate_prompt(
-                [self._convert_input(input)], stop=stop, **(config or {}), **kwargs
+                [self._convert_input(input)],
+                stop=stop,
+                callbacks=config.get("callbacks"),
+                tags=config.get("tags"),
+                metadata=config.get("metadata"),
+                **kwargs,
             )
             .generations[0][0]
             .text
@@ -241,8 +247,14 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                 None, partial(self.invoke, input, config, stop=stop, **kwargs)
             )
 
+        config = config or {}
         llm_result = await self.agenerate_prompt(
-            [self._convert_input(input)], stop=stop, **(config or {}), **kwargs
+            [self._convert_input(input)],
+            stop=stop,
+            callbacks=config.get("callbacks"),
+            tags=config.get("tags"),
+            metadata=config.get("metadata"),
+            **kwargs,
         )
         return llm_result.generations[0][0].text
 
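
Note: the chat-model and LLM hunks are symmetric: BaseChatModel.invoke unwraps generations[0][0].message from the result, while BaseLLM.invoke unwraps generations[0][0].text. A tiny sketch of that indexing, with plain dicts standing in for the real result schema:

    # generations is a list of lists: one inner list per input prompt.
    chat_result = {"generations": [[{"message": "<AIMessage>"}]]}
    llm_result = {"generations": [[{"text": "turtles all the way down"}]]}

    print(chat_result["generations"][0][0]["message"])  # what BaseChatModel.invoke returns
    print(llm_result["generations"][0][0]["text"])      # what BaseLLM.invoke returns
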
libs/langchain/langchain/runnables/__init__.py (new file, 0 lines)
libs/langchain/langchain/runnables/openai_functions.py (new file, 46 lines)
@@ -0,0 +1,46 @@
+from operator import itemgetter
+from typing import Any, Callable, List, Mapping, Optional, Union
+
+from typing_extensions import TypedDict
+
+from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
+from langchain.schema.output import ChatGeneration
+from langchain.schema.runnable import RouterRunnable, Runnable, RunnableBinding
+
+
+class OpenAIFunction(TypedDict):
+    """A function description for ChatOpenAI"""
+
+    name: str
+    """The name of the function."""
+    description: str
+    """The description of the function."""
+    parameters: dict
+    """The parameters to the function."""
+
+
+class OpenAIFunctionsRouter(RunnableBinding[ChatGeneration, Any]):
+    """A runnable that routes to the selected function."""
+
+    functions: Optional[List[OpenAIFunction]]
+
+    def __init__(
+        self,
+        runnables: Mapping[
+            str,
+            Union[
+                Runnable[dict, Any],
+                Callable[[dict], Any],
+            ],
+        ],
+        functions: Optional[List[OpenAIFunction]] = None,
+    ):
+        if functions is not None:
+            assert len(functions) == len(runnables)
+            assert all(func["name"] in runnables for func in functions)
+        router = (
+            JsonOutputFunctionsParser(args_only=False)
+            | {"key": itemgetter("name"), "input": itemgetter("arguments")}
+            | RouterRunnable(runnables)
+        )
+        super().__init__(bound=router, kwargs={}, functions=functions)
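
Note: the new OpenAIFunctionsRouter chains three stages with the runnable pipe operator: JsonOutputFunctionsParser(args_only=False) turns the model's function_call into a dict with "name" and "arguments", the mapping of itemgetters reshapes that into the {"key", "input"} dict RouterRunnable expects, and RouterRunnable dispatches to the runnable registered under that key. A plain-Python sketch of the data flow (stand-in lambdas, not the real classes):

    from operator import itemgetter

    # Stage 1 output: what JsonOutputFunctionsParser(args_only=False) yields.
    parsed = {"name": "accept", "arguments": {"draft": "turtles"}}

    # Stage 2: reshape into the routing envelope RouterRunnable consumes.
    routed = {"key": itemgetter("name")(parsed), "input": itemgetter("arguments")(parsed)}

    # Stage 3: look up the target by key and invoke it with the arguments.
    runnables = {"accept": lambda kw: f'Accepted draft: {kw["draft"]}!'}
    print(runnables[routed["key"]](routed["input"]))  # Accepted draft: turtles!
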
@@ -107,7 +107,13 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
     def invoke(
         self, input: str, config: Optional[RunnableConfig] = None
     ) -> List[Document]:
-        return self.get_relevant_documents(input, **(config or {}))
+        config = config or {}
+        return self.get_relevant_documents(
+            input,
+            callbacks=config.get("callbacks"),
+            tags=config.get("tags"),
+            metadata=config.get("metadata"),
+        )
 
     async def ainvoke(
         self, input: str, config: Optional[RunnableConfig] = None
@@ -116,7 +122,13 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
             # If the retriever doesn't implement async, use default implementation
            return await super().ainvoke(input, config)
 
-        return await self.aget_relevant_documents(input, **(config or {}))
+        config = config or {}
+        return await self.aget_relevant_documents(
+            input,
+            callbacks=config.get("callbacks"),
+            tags=config.get("tags"),
+            metadata=config.get("metadata"),
+        )
 
     @abstractmethod
     def _get_relevant_documents(
@@ -1254,7 +1254,7 @@ class RunnablePassthrough(Serializable, Runnable[Input, Input]):
 
 class RunnableBinding(Serializable, Runnable[Input, Output]):
     """
-    A runnable that binds a runnable to a set of kwargs.
+    A runnable that delegates calls to another runnable with a set of kwargs.
     """
 
     bound: Runnable[Input, Output]
@@ -1339,8 +1339,15 @@ class RouterRunnable(
 
     runnables: Mapping[str, Runnable[Input, Output]]
 
-    def __init__(self, runnables: Mapping[str, Runnable[Input, Output]]) -> None:
-        super().__init__(runnables=runnables)
+    def __init__(
+        self,
+        runnables: Mapping[
+            str, Union[Runnable[Input, Output], Callable[[Input], Output]]
+        ],
+    ) -> None:
+        super().__init__(
+            runnables={key: _coerce_to_runnable(r) for key, r in runnables.items()}
+        )
 
     class Config:
         arbitrary_types_allowed = True
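
Note: accepting Callable alongside Runnable here is what lets OpenAIFunctionsRouter above (and the test below) pass bare functions; _coerce_to_runnable wraps them so they can be invoked uniformly. A simplified sketch of the coercion idea (RunnableLike is a stand-in, and the real helper handles more cases):

    from typing import Any, Callable, Union

    class RunnableLike:
        # Stand-in for Runnable: anything exposing .invoke().
        def __init__(self, fn: Callable[[Any], Any]) -> None:
            self.fn = fn

        def invoke(self, input: Any) -> Any:
            return self.fn(input)

    def coerce(thing: Union[RunnableLike, Callable[[Any], Any]]) -> RunnableLike:
        # Pass runnables through unchanged; wrap bare callables.
        return thing if isinstance(thing, RunnableLike) else RunnableLike(thing)

    coerced = {k: coerce(v) for k, v in {"upper": str.upper}.items()}
    print(coerced["upper"].invoke("hi"))  # HI
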
@@ -203,7 +203,13 @@ class BaseTool(BaseModel, Runnable[Union[str, Dict], Any], metaclass=ToolMetacla
         **kwargs: Any,
     ) -> Any:
         config = config or {}
-        return self.run(input, **config, **kwargs)
+        return self.run(
+            input,
+            callbacks=config.get("callbacks"),
+            tags=config.get("tags"),
+            metadata=config.get("metadata"),
+            **kwargs,
+        )
 
     async def ainvoke(
         self,
@@ -216,7 +222,13 @@ class BaseTool(BaseModel, Runnable[Union[str, Dict], Any], metaclass=ToolMetacla
            return super().ainvoke(input, config, **kwargs)
 
         config = config or {}
-        return await self.arun(input, **config, **kwargs)
+        return await self.arun(
+            input,
+            callbacks=config.get("callbacks"),
+            tags=config.get("tags"),
+            metadata=config.get("metadata"),
+            **kwargs,
+        )
 
     # --- Tool ---
 
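
Note: for tools, config = config or {} already existed; the change only narrows what reaches run/arun. Splatting the whole config was fragile because any unrecognized key became an unexpected keyword argument. An illustrative stand-in (the extra key is hypothetical):

    def run(tool_input, callbacks=None, tags=None, metadata=None):
        return f"ran {tool_input!r} with tags={tags}"

    config = {"tags": ["demo"], "some_future_key": 1}
    try:
        run("query", **config)  # old behavior
    except TypeError as err:
        print(err)  # run() got an unexpected keyword argument 'some_future_key'

    print(run("query", tags=config.get("tags")))  # new behavior: known keys only
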
@@ -0,0 +1,31 @@
+# serializer version: 1
+# name: test_openai_functions_router
+  list([
+    dict({
+      'description': 'Sends the draft for revision.',
+      'name': 'revise',
+      'parameters': dict({
+        'properties': dict({
+          'notes': dict({
+            'description': "The editor's notes to guide the revision.",
+            'type': 'string',
+          }),
+        }),
+        'type': 'object',
+      }),
+    }),
+    dict({
+      'description': 'Accepts the draft.',
+      'name': 'accept',
+      'parameters': dict({
+        'properties': dict({
+          'draft': dict({
+            'description': 'The draft to accept.',
+            'type': 'string',
+          }),
+        }),
+        'type': 'object',
+      }),
+    }),
+  ])
+# ---
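
Note: this syrupy .ambr snapshot pins the exact functions payload the router exposes; the test below compares router.functions against it. Assuming standard syrupy usage, the file is regenerated with pytest --snapshot-update rather than edited by hand.
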
@@ -0,0 +1,95 @@
+from typing import Any, List, Optional
+
+from pytest_mock import MockerFixture
+from syrupy import SnapshotAssertion
+
+from langchain.callbacks.manager import CallbackManagerForLLMRun
+from langchain.chat_models.base import BaseChatModel
+from langchain.runnables.openai_functions import OpenAIFunctionsRouter
+from langchain.schema import ChatResult
+from langchain.schema.messages import AIMessage, BaseMessage
+from langchain.schema.output import ChatGeneration
+
+
+class FakeChatOpenAI(BaseChatModel):
+    @property
+    def _llm_type(self) -> str:
+        return "fake-openai-chat-model"
+
+    def _generate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        return ChatResult(
+            generations=[
+                ChatGeneration(
+                    message=AIMessage(
+                        content="",
+                        additional_kwargs={
+                            "function_call": {
+                                "name": "accept",
+                                "arguments": '{\n "draft": "turtles"\n}',
+                            }
+                        },
+                    )
+                )
+            ]
+        )
+
+
+def test_openai_functions_router(
+    snapshot: SnapshotAssertion, mocker: MockerFixture
+) -> None:
+    revise = mocker.Mock(
+        side_effect=lambda kw: f'Revised draft: no more {kw["notes"]}!'
+    )
+    accept = mocker.Mock(side_effect=lambda kw: f'Accepted draft: {kw["draft"]}!')
+
+    router = OpenAIFunctionsRouter(
+        {
+            "revise": revise,
+            "accept": accept,
+        },
+        functions=[
+            {
+                "name": "revise",
+                "description": "Sends the draft for revision.",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "notes": {
+                            "type": "string",
+                            "description": "The editor's notes to guide the revision.",
+                        },
+                    },
+                },
+            },
+            {
+                "name": "accept",
+                "description": "Accepts the draft.",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "draft": {
+                            "type": "string",
+                            "description": "The draft to accept.",
+                        },
+                    },
+                },
+            },
+        ],
+    )
+
+    model = FakeChatOpenAI()
+
+    chain = model.bind(functions=router.functions) | router
+
+    assert router.functions == snapshot
+
+    assert chain.invoke("Something about turtles?") == "Accepted draft: turtles!"
+
+    revise.assert_not_called()
+    accept.assert_called_once_with({"draft": "turtles"})
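
Note: FakeChatOpenAI ignores its input and always answers with a function_call selecting accept, so the router must dispatch to accept with the decoded arguments and never touch revise; the final assertions pin both the routed result and the mock call counts. A sketch of the parse step the router performs on that message (plain json as a stand-in for the output parser):

    import json

    # What the fake model emits in additional_kwargs["function_call"]:
    function_call = {"name": "accept", "arguments": '{\n "draft": "turtles"\n}'}

    # args_only=False keeps the name alongside the decoded arguments:
    parsed = {
        "name": function_call["name"],
        "arguments": json.loads(function_call["arguments"]),
    }
    assert parsed == {"name": "accept", "arguments": {"draft": "turtles"}}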