Bagatur 2023-08-09 14:44:29 -07:00
commit 05cdd22c39
15 changed files with 280 additions and 50 deletions

View File

@@ -12,7 +12,7 @@ Here are the agents available in LangChain.

### [Zero-shot ReAct](/docs/modules/agents/agent_types/react.html)

-This agent uses the [ReAct](https://arxiv.org/pdf/2205.00445.pdf) framework to determine which tool to use
+This agent uses the [ReAct](https://arxiv.org/pdf/2210.03629) framework to determine which tool to use
based solely on the tool's description. Any number of tools can be provided.
This agent requires that a description is provided for each tool.

View File

@@ -62,9 +62,11 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
-        _config: Dict[str, Any] = dict(config) if config else {}
-        _config.pop("_locals", None)
-        return self(input, **_config, **kwargs)
+        config = config or {}
+        config_kwargs: Dict = {
+            k: config.get(k) for k in ("callbacks", "tags", "metadata")
+        }
+        return self(input, **config_kwargs, **kwargs)

    async def ainvoke(
        self,
@@ -77,10 +79,11 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
            return await asyncio.get_running_loop().run_in_executor(
                None, partial(self.invoke, input, config, **kwargs)
            )
-        _config: Dict[str, Any] = dict(config) if config else {}
-        _config.pop("_locals", None)
-        return await self.acall(input, **_config, **kwargs)
+        config = config or {}
+        config_kwargs: Dict = {
+            k: config.get(k) for k in ("callbacks", "tags", "metadata")
+        }
+        return await self.acall(input, **config_kwargs, **kwargs)

    memory: Optional[BaseMemory] = None
    """Optional memory object. Defaults to None.

View File

@@ -103,14 +103,16 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> BaseMessageChunk:
-        _config: Dict[str, Any] = dict(config or {})
-        _config.pop("_locals", None)
+        config = config or {}
+        config_kwargs: Dict = {
+            k: config.get(k) for k in ("callbacks", "tags", "metadata")
+        }
        return cast(
            BaseMessageChunk,
            cast(
                ChatGeneration,
                self.generate_prompt(
-                    [self._convert_input(input)], stop=stop, **_config, **kwargs
+                    [self._convert_input(input)], stop=stop, **config_kwargs, **kwargs
                ).generations[0][0],
            ).message,
        )
@@ -129,10 +131,12 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
                None, partial(self.invoke, input, config, stop=stop, **kwargs)
            )
-        _config: Dict[str, Any] = dict(config or {})
-        _config.pop("_locals", None)
+        config = config or {}
+        config_kwargs: Dict = {
+            k: config.get(k) for k in ("callbacks", "tags", "metadata")
+        }
        llm_result = await self.agenerate_prompt(
-            [self._convert_input(input)], stop=stop, **_config, **kwargs
+            [self._convert_input(input)], stop=stop, **config_kwargs, **kwargs
        )
        return cast(
            BaseMessageChunk, cast(ChatGeneration, llm_result.generations[0][0]).message

View File

@@ -1,10 +1,9 @@
"""Loads local airbyte json files."""
from typing import Any, Callable, Iterator, List, Mapping, Optional

-from libs.langchain.langchain.utils.utils import guard_import
from langchain.docstore.document import Document
from langchain.document_loaders.base import BaseLoader
+from langchain.utils.utils import guard_import

RecordHandler = Callable[[Any, Optional[str]], Document]

View File

@@ -219,10 +219,12 @@ class BaseLLM(BaseLanguageModel[str], ABC):
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> str:
-        _config: Dict[str, Any] = dict(config or {})
-        _config.pop("_locals", None)
+        config = config or {}
+        config_kwargs: Dict = {
+            k: config.get(k) for k in ("callbacks", "tags", "metadata")
+        }
        result = self.generate_prompt(
-            [self._convert_input(input)], stop=stop, **_config, **kwargs
+            [self._convert_input(input)], stop=stop, **config_kwargs, **kwargs
        )
        return result.generations[0][0].text
@@ -240,10 +242,12 @@ class BaseLLM(BaseLanguageModel[str], ABC):
            None, partial(self.invoke, input, config, stop=stop, **kwargs)
        )
-        _config: Dict[str, Any] = dict(config or {})
-        _config.pop("_locals", None)
+        config = config or {}
+        config_kwargs: Dict = {
+            k: config.get(k) for k in ("callbacks", "tags", "metadata")
+        }
        llm_result = await self.agenerate_prompt(
-            [self._convert_input(input)], stop=stop, **_config, **kwargs
+            [self._convert_input(input)], stop=stop, **config_kwargs, **kwargs
        )
        return llm_result.generations[0][0].text

View File

@@ -0,0 +1,46 @@
from operator import itemgetter
from typing import Any, Callable, List, Mapping, Optional, Union

from typing_extensions import TypedDict

from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
from langchain.schema.output import ChatGeneration
from langchain.schema.runnable import RouterRunnable, Runnable, RunnableBinding


class OpenAIFunction(TypedDict):
    """A function description for ChatOpenAI"""

    name: str
    """The name of the function."""
    description: str
    """The description of the function."""
    parameters: dict
    """The parameters to the function."""


class OpenAIFunctionsRouter(RunnableBinding[ChatGeneration, Any]):
    """A runnable that routes to the selected function."""

    functions: Optional[List[OpenAIFunction]]

    def __init__(
        self,
        runnables: Mapping[
            str,
            Union[
                Runnable[dict, Any],
                Callable[[dict], Any],
            ],
        ],
        functions: Optional[List[OpenAIFunction]] = None,
    ):
        if functions is not None:
            assert len(functions) == len(runnables)
            assert all(func["name"] in runnables for func in functions)
        router = (
            JsonOutputFunctionsParser(args_only=False)
            | {"key": itemgetter("name"), "input": itemgetter("arguments")}
            | RouterRunnable(runnables)
        )
        super().__init__(bound=router, kwargs={}, functions=functions)
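
For orientation, a hedged usage sketch (not part of this commit; handler names are illustrative): the router parses the model's `function_call`, maps `name` to `key` and `arguments` to `input`, and dispatches to the matching runnable. The snapshot test added later in this commit shows the full wiring.

    # Sketch only; assumes a chat model that emits OpenAI function calls.
    router = OpenAIFunctionsRouter(
        {
            "revise": lambda args: f"revise using notes: {args['notes']}",
            "accept": lambda args: f"accepted: {args['draft']}",
        }
    )
    # chain = ChatOpenAI().bind(functions=[...]) | router
    # chain.invoke("Draft a tweet about turtles")  # routes to 'revise' or 'accept'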

View File

@@ -107,9 +107,11 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
    def invoke(
        self, input: str, config: Optional[RunnableConfig] = None
    ) -> List[Document]:
-        _config: Dict[str, Any] = dict(config or {})
-        _config.pop("_locals", None)
-        return self.get_relevant_documents(input, **_config)
+        config = config or {}
+        config_kwargs: Dict = {
+            k: config.get(k) for k in ("callbacks", "tags", "metadata")
+        }
+        return self.get_relevant_documents(input, **config_kwargs)

    async def ainvoke(
        self, input: str, config: Optional[RunnableConfig] = None
@@ -118,9 +120,11 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
            # If the retriever doesn't implement async, use default implementation
            return await super().ainvoke(input, config)
-        _config: Dict[str, Any] = dict(config or {})
-        _config.pop("_locals", None)
-        return await self.aget_relevant_documents(input, **_config)
+        config = config or {}
+        config_kwargs: Dict = {
+            k: config.get(k) for k in ("callbacks", "tags", "metadata")
+        }
+        return await self.aget_relevant_documents(input, **config_kwargs)

    @abstractmethod
    def _get_relevant_documents(

View File

@@ -1229,7 +1229,7 @@ class RunnablePassthrough(Serializable, Runnable[Input, Input]):

class RunnableBinding(Serializable, Runnable[Input, Output]):
    """
-    A runnable that binds a runnable to a set of kwargs.
+    A runnable that delegates calls to another runnable with a set of kwargs.
    """

    bound: Runnable[Input, Output]
@@ -1314,8 +1314,15 @@ class RouterRunnable(
    runnables: Mapping[str, Runnable[Input, Output]]

-    def __init__(self, runnables: Mapping[str, Runnable[Input, Output]]) -> None:
-        super().__init__(runnables=runnables)
+    def __init__(
+        self,
+        runnables: Mapping[
+            str, Union[Runnable[Input, Output], Callable[[Input], Output]]
+        ],
+    ) -> None:
+        super().__init__(
+            runnables={key: _coerce_to_runnable(r) for key, r in runnables.items()}
+        )

    class Config:
        arbitrary_types_allowed = True
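
A hedged sketch (not part of this commit) of what the widened constructor accepts: plain callables are coerced to runnables, and the router then dispatches on the `{"key": ..., "input": ...}` shape it already expects. The keys and lambdas are illustrative.

    # Sketch only; illustrative keys and handlers.
    from langchain.schema.runnable import RouterRunnable

    router = RouterRunnable(
        {
            "square": lambda x: x * x,  # plain callable, coerced to a runnable
            "negate": lambda x: -x,
        }
    )
    router.invoke({"key": "square", "input": 4})  # -> 16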

View File

@@ -502,6 +502,18 @@ def _construct_run_evaluator(
    return run_evaluator


+def _get_keys(
+    config: RunEvalConfig,
+    run_inputs: Optional[List[str]],
+    run_outputs: Optional[List[str]],
+    example_outputs: Optional[List[str]],
+) -> Tuple[Optional[str], Optional[str], Optional[str]]:
+    input_key = _determine_input_key(config, run_inputs)
+    prediction_key = _determine_prediction_key(config, run_outputs)
+    reference_key = _determine_reference_key(config, example_outputs)
+    return input_key, prediction_key, reference_key
+
+
def _load_run_evaluators(
    config: RunEvalConfig,
    run_type: str,
@@ -521,9 +533,13 @@ def _load_run_evaluators(
    """
    eval_llm = config.eval_llm or ChatOpenAI(model="gpt-4", temperature=0.0)
    run_evaluators = []
-    input_key = _determine_input_key(config, run_inputs)
-    prediction_key = _determine_prediction_key(config, run_outputs)
-    reference_key = _determine_reference_key(config, example_outputs)
+    input_key, prediction_key, reference_key = None, None, None
+    if config.evaluators or any(
+        [isinstance(e, EvaluatorType) for e in config.evaluators]
+    ):
+        input_key, prediction_key, reference_key = _get_keys(
+            config, run_inputs, run_outputs, example_outputs
+        )
    for eval_config in config.evaluators:
        run_evaluator = _construct_run_evaluator(
            eval_config,
@@ -1078,15 +1094,15 @@ def _run_on_examples(
        A dictionary mapping example ids to the model outputs.
    """
    results: Dict[str, Any] = {}
-    llm_or_chain_factory = _wrap_in_chain_factory(llm_or_chain_factory)
-    project_name = _get_project_name(project_name, llm_or_chain_factory)
+    wrapped_model = _wrap_in_chain_factory(llm_or_chain_factory)
+    project_name = _get_project_name(project_name, wrapped_model)
    tracer = LangChainTracer(
        project_name=project_name, client=client, use_threading=False
    )
    run_evaluators, examples = _setup_evaluation(
-        llm_or_chain_factory, examples, evaluation, data_type
+        wrapped_model, examples, evaluation, data_type
    )
-    examples = _validate_example_inputs(examples, llm_or_chain_factory, input_mapper)
+    examples = _validate_example_inputs(examples, wrapped_model, input_mapper)
    evalution_handler = EvaluatorCallbackHandler(
        evaluators=run_evaluators or [],
        client=client,
@@ -1095,7 +1111,7 @@ def _run_on_examples(
    for i, example in enumerate(examples):
        result = _run_llm_or_chain(
            example,
-            llm_or_chain_factory,
+            wrapped_model,
            num_repetitions,
            tags=tags,
            callbacks=callbacks,
@@ -1118,8 +1134,8 @@ def _prepare_eval_run(
    llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
    project_name: Optional[str],
) -> Tuple[MCF, str, Dataset, Iterator[Example]]:
-    llm_or_chain_factory = _wrap_in_chain_factory(llm_or_chain_factory, dataset_name)
-    project_name = _get_project_name(project_name, llm_or_chain_factory)
+    wrapped_model = _wrap_in_chain_factory(llm_or_chain_factory, dataset_name)
+    project_name = _get_project_name(project_name, wrapped_model)
    try:
        project = client.create_project(project_name)
    except ValueError as e:
@@ -1134,7 +1150,7 @@ def _prepare_eval_run(
    )
    dataset = client.read_dataset(dataset_name=dataset_name)
    examples = client.list_examples(dataset_id=str(dataset.id))
-    return llm_or_chain_factory, project_name, dataset, examples
+    return wrapped_model, project_name, dataset, examples


async def arun_on_dataset(
@@ -1260,13 +1276,13 @@ async def arun_on_dataset(
            evaluation=evaluation_config,
        )
    """  # noqa: E501
-    llm_or_chain_factory, project_name, dataset, examples = _prepare_eval_run(
+    wrapped_model, project_name, dataset, examples = _prepare_eval_run(
        client, dataset_name, llm_or_chain_factory, project_name
    )
    results = await _arun_on_examples(
        client,
        examples,
-        llm_or_chain_factory,
+        wrapped_model,
        concurrency_level=concurrency_level,
        num_repetitions=num_repetitions,
        project_name=project_name,
@@ -1427,14 +1443,14 @@ def run_on_dataset(
            evaluation=evaluation_config,
        )
    """  # noqa: E501
-    llm_or_chain_factory, project_name, dataset, examples = _prepare_eval_run(
+    wrapped_model, project_name, dataset, examples = _prepare_eval_run(
        client, dataset_name, llm_or_chain_factory, project_name
    )
    if concurrency_level in (0, 1):
        results = _run_on_examples(
            client,
            examples,
-            llm_or_chain_factory,
+            wrapped_model,
            num_repetitions=num_repetitions,
            project_name=project_name,
            verbose=verbose,
@@ -1448,7 +1464,7 @@ def run_on_dataset(
        coro = _arun_on_examples(
            client,
            examples,
-            llm_or_chain_factory,
+            wrapped_model,
            concurrency_level=concurrency_level,
            num_repetitions=num_repetitions,
            project_name=project_name,
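
For context, a hedged sketch (not part of this commit) of the public entry point these helpers serve; it follows the docstring excerpt above, and the dataset name and model factory are placeholders.

    # Sketch only; "my-dataset" and the model factory are placeholders.
    from langsmith import Client
    from langchain.chat_models import ChatOpenAI
    from langchain.evaluation import EvaluatorType
    from langchain.smith import RunEvalConfig, run_on_dataset

    client = Client()
    evaluation_config = RunEvalConfig(evaluators=[EvaluatorType.QA])
    run_on_dataset(
        client,
        "my-dataset",
        lambda: ChatOpenAI(temperature=0),  # llm_or_chain_factory
        evaluation=evaluation_config,
    )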

View File

@@ -203,7 +203,13 @@ class BaseTool(BaseModel, Runnable[Union[str, Dict], Any], metaclass=ToolMetacla
        **kwargs: Any,
    ) -> Any:
        config = config or {}
-        return self.run(input, **config, **kwargs)
+        return self.run(
+            input,
+            callbacks=config.get("callbacks"),
+            tags=config.get("tags"),
+            metadata=config.get("metadata"),
+            **kwargs,
+        )

    async def ainvoke(
        self,
@@ -216,7 +222,13 @@ class BaseTool(BaseModel, Runnable[Union[str, Dict], Any], metaclass=ToolMetacla
            return super().ainvoke(input, config, **kwargs)
        config = config or {}
-        return await self.arun(input, **config, **kwargs)
+        return await self.arun(
+            input,
+            callbacks=config.get("callbacks"),
+            tags=config.get("tags"),
+            metadata=config.get("metadata"),
+            **kwargs,
+        )

    # --- Tool ---
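
A hedged caller-side sketch (not part of this commit) of what now reaches `run`: only `callbacks`, `tags`, and `metadata` are pulled from the config, and other kwargs pass through unchanged. The tool below is illustrative.

    # Sketch only; the tool definition is illustrative.
    from langchain.tools import tool

    @tool
    def word_count(text: str) -> int:
        """Count the words in a text."""
        return len(text.split())

    word_count.invoke(
        "how many words is this",
        config={"tags": ["demo"], "metadata": {"caller": "docs"}},
    )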

View File

@@ -1,6 +1,6 @@
[tool.poetry]
name = "langchain"
-version = "0.0.259"
+version = "0.0.260"
description = "Building applications with LLMs through composability"
authors = []
license = "MIT"

View File

@@ -0,0 +1,9 @@
"""Test the airbyte document loader.

Light test to ensure that the airbyte document loader can be imported.
"""


def test_airbyte_import() -> None:
    """Test that the airbyte document loader can be imported."""
    from langchain.document_loaders import airbyte  # noqa

View File

@@ -0,0 +1,31 @@
# serializer version: 1
# name: test_openai_functions_router
  list([
    dict({
      'description': 'Sends the draft for revision.',
      'name': 'revise',
      'parameters': dict({
        'properties': dict({
          'notes': dict({
            'description': "The editor's notes to guide the revision.",
            'type': 'string',
          }),
        }),
        'type': 'object',
      }),
    }),
    dict({
      'description': 'Accepts the draft.',
      'name': 'accept',
      'parameters': dict({
        'properties': dict({
          'draft': dict({
            'description': 'The draft to accept.',
            'type': 'string',
          }),
        }),
        'type': 'object',
      }),
    }),
  ])
# ---

View File

@@ -0,0 +1,95 @@
from typing import Any, List, Optional

from pytest_mock import MockerFixture
from syrupy import SnapshotAssertion

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models.base import BaseChatModel
from langchain.runnables.openai_functions import OpenAIFunctionsRouter
from langchain.schema import ChatResult
from langchain.schema.messages import AIMessage, BaseMessage
from langchain.schema.output import ChatGeneration


class FakeChatOpenAI(BaseChatModel):
    @property
    def _llm_type(self) -> str:
        return "fake-openai-chat-model"

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        return ChatResult(
            generations=[
                ChatGeneration(
                    message=AIMessage(
                        content="",
                        additional_kwargs={
                            "function_call": {
                                "name": "accept",
                                "arguments": '{\n "draft": "turtles"\n}',
                            }
                        },
                    )
                )
            ]
        )


def test_openai_functions_router(
    snapshot: SnapshotAssertion, mocker: MockerFixture
) -> None:
    revise = mocker.Mock(
        side_effect=lambda kw: f'Revised draft: no more {kw["notes"]}!'
    )
    accept = mocker.Mock(side_effect=lambda kw: f'Accepted draft: {kw["draft"]}!')

    router = OpenAIFunctionsRouter(
        {
            "revise": revise,
            "accept": accept,
        },
        functions=[
            {
                "name": "revise",
                "description": "Sends the draft for revision.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "notes": {
                            "type": "string",
                            "description": "The editor's notes to guide the revision.",
                        },
                    },
                },
            },
            {
                "name": "accept",
                "description": "Accepts the draft.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "draft": {
                            "type": "string",
                            "description": "The draft to accept.",
                        },
                    },
                },
            },
        ],
    )

    model = FakeChatOpenAI()
    chain = model.bind(functions=router.functions) | router

    assert router.functions == snapshot

    assert chain.invoke("Something about turtles?") == "Accepted draft: turtles!"

    revise.assert_not_called()
    accept.assert_called_once_with({"draft": "turtles"})