commit 05cdd22c39 (mirror of https://github.com/hwchase17/langchain.git)
@@ -12,7 +12,7 @@ Here are the agents available in LangChain.
 
 ### [Zero-shot ReAct](/docs/modules/agents/agent_types/react.html)
 
-This agent uses the [ReAct](https://arxiv.org/pdf/2205.00445.pdf) framework to determine which tool to use
+This agent uses the [ReAct](https://arxiv.org/pdf/2210.03629) framework to determine which tool to use
 based solely on the tool's description. Any number of tools can be provided.
 This agent requires that a description is provided for each tool.
@@ -62,9 +62,11 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
         config: Optional[RunnableConfig] = None,
         **kwargs: Any,
     ) -> Dict[str, Any]:
-        _config: Dict[str, Any] = dict(config) if config else {}
-        _config.pop("_locals", None)
-        return self(input, **_config, **kwargs)
+        config = config or {}
+        config_kwargs: Dict = {
+            k: config.get(k) for k in ("callbacks", "tags", "metadata")
+        }
+        return self(input, **config_kwargs, **kwargs)
 
     async def ainvoke(
         self,
@@ -77,10 +79,11 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
             return await asyncio.get_running_loop().run_in_executor(
                 None, partial(self.invoke, input, config, **kwargs)
             )
-
-        _config: Dict[str, Any] = dict(config) if config else {}
-        _config.pop("_locals", None)
-        return await self.acall(input, **_config, **kwargs)
+        config = config or {}
+        config_kwargs: Dict = {
+            k: config.get(k) for k in ("callbacks", "tags", "metadata")
+        }
+        return await self.acall(input, **config_kwargs, **kwargs)
 
     memory: Optional[BaseMemory] = None
     """Optional memory object. Defaults to None.
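The invoke/ainvoke overrides above (and the analogous ones in the chat model, LLM, and retriever classes below) all switch from forwarding the whole config to forwarding an explicit whitelist. A standalone sketch of the pattern, with RunnableConfig simplified to a plain dict for illustration (the helper name is made up):

    from typing import Any, Dict, Optional

    def _config_kwargs(config: Optional[Dict[str, Any]]) -> Dict[str, Any]:
        # Keep only the keys the downstream call signatures accept; every
        # other config key (e.g. the old "_locals") is simply dropped
        # instead of being popped one by one.
        config = config or {}
        return {k: config.get(k) for k in ("callbacks", "tags", "metadata")}

    print(_config_kwargs({"tags": ["demo"], "_locals": {}}))
    # {'callbacks': None, 'tags': ['demo'], 'metadata': None}

Note that config.get(k) yields None for absent keys, so callbacks, tags, and metadata are always passed explicitly (as None when unset).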
@@ -103,14 +103,16 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
         stop: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> BaseMessageChunk:
-        _config: Dict[str, Any] = dict(config or {})
-        _config.pop("_locals", None)
+        config = config or {}
+        config_kwargs: Dict = {
+            k: config.get(k) for k in ("callbacks", "tags", "metadata")
+        }
         return cast(
             BaseMessageChunk,
             cast(
                 ChatGeneration,
                 self.generate_prompt(
-                    [self._convert_input(input)], stop=stop, **_config, **kwargs
+                    [self._convert_input(input)], stop=stop, **config_kwargs, **kwargs
                 ).generations[0][0],
             ).message,
         )
@@ -129,10 +131,12 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
                 None, partial(self.invoke, input, config, stop=stop, **kwargs)
             )
 
-        _config: Dict[str, Any] = dict(config or {})
-        _config.pop("_locals", None)
+        config = config or {}
+        config_kwargs: Dict = {
+            k: config.get(k) for k in ("callbacks", "tags", "metadata")
+        }
         llm_result = await self.agenerate_prompt(
-            [self._convert_input(input)], stop=stop, **_config, **kwargs
+            [self._convert_input(input)], stop=stop, **config_kwargs, **kwargs
         )
         return cast(
             BaseMessageChunk, cast(ChatGeneration, llm_result.generations[0][0]).message
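From the caller's side nothing changes; config is passed the same way as before. A usage sketch (the model choice and callback handler are illustrative, and an OPENAI_API_KEY is assumed):

    from langchain.callbacks import StdOutCallbackHandler
    from langchain.chat_models import ChatOpenAI

    chat = ChatOpenAI()
    # callbacks, tags, and metadata are extracted from the config and handed
    # to generate_prompt; any other config keys are now ignored.
    message = chat.invoke(
        "Say hi",
        config={"callbacks": [StdOutCallbackHandler()], "tags": ["demo"]},
    )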
@@ -1,10 +1,9 @@
 """Loads local airbyte json files."""
 from typing import Any, Callable, Iterator, List, Mapping, Optional
 
-from libs.langchain.langchain.utils.utils import guard_import
-
 from langchain.docstore.document import Document
 from langchain.document_loaders.base import BaseLoader
+from langchain.utils.utils import guard_import
 
 RecordHandler = Callable[[Any, Optional[str]], Document]
 
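The fix replaces a repo-relative import path (libs.langchain.langchain...) with the installed package path, so the loader works outside the monorepo checkout. For context, a sketch of what a guard_import-style helper typically does (illustrative only; the real implementation lives in langchain.utils.utils and may differ):

    import importlib
    from typing import Any, Optional

    def guard_import(module_name: str, *, pip_name: Optional[str] = None) -> Any:
        """Import a module, raising a readable error with an install hint."""
        try:
            return importlib.import_module(module_name)
        except ImportError as e:
            pip = pip_name or module_name.split(".")[0]
            raise ImportError(
                f"Could not import {module_name}. "
                f"Please install it with `pip install {pip}`."
            ) from e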
@@ -219,10 +219,12 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         stop: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> str:
-        _config: Dict[str, Any] = dict(config or {})
-        _config.pop("_locals", None)
+        config = config or {}
+        config_kwargs: Dict = {
+            k: config.get(k) for k in ("callbacks", "tags", "metadata")
+        }
         result = self.generate_prompt(
-            [self._convert_input(input)], stop=stop, **_config, **kwargs
+            [self._convert_input(input)], stop=stop, **config_kwargs, **kwargs
         )
         return result.generations[0][0].text
 
@@ -240,10 +242,12 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                 None, partial(self.invoke, input, config, stop=stop, **kwargs)
             )
 
-        _config: Dict[str, Any] = dict(config or {})
-        _config.pop("_locals", None)
+        config = config or {}
+        config_kwargs: Dict = {
+            k: config.get(k) for k in ("callbacks", "tags", "metadata")
+        }
         llm_result = await self.agenerate_prompt(
-            [self._convert_input(input)], stop=stop, **_config, **kwargs
+            [self._convert_input(input)], stop=stop, **config_kwargs, **kwargs
         )
         return llm_result.generations[0][0].text
 
New files:

libs/langchain/langchain/runnables/__init__.py (new, empty)

libs/langchain/langchain/runnables/openai_functions.py (new, 46 lines):
@@ -0,0 +1,46 @@
+from operator import itemgetter
+from typing import Any, Callable, List, Mapping, Optional, Union
+
+from typing_extensions import TypedDict
+
+from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
+from langchain.schema.output import ChatGeneration
+from langchain.schema.runnable import RouterRunnable, Runnable, RunnableBinding
+
+
+class OpenAIFunction(TypedDict):
+    """A function description for ChatOpenAI"""
+
+    name: str
+    """The name of the function."""
+    description: str
+    """The description of the function."""
+    parameters: dict
+    """The parameters to the function."""
+
+
+class OpenAIFunctionsRouter(RunnableBinding[ChatGeneration, Any]):
+    """A runnable that routes to the selected function."""
+
+    functions: Optional[List[OpenAIFunction]]
+
+    def __init__(
+        self,
+        runnables: Mapping[
+            str,
+            Union[
+                Runnable[dict, Any],
+                Callable[[dict], Any],
+            ],
+        ],
+        functions: Optional[List[OpenAIFunction]] = None,
+    ):
+        if functions is not None:
+            assert len(functions) == len(runnables)
+            assert all(func["name"] in runnables for func in functions)
+        router = (
+            JsonOutputFunctionsParser(args_only=False)
+            | {"key": itemgetter("name"), "input": itemgetter("arguments")}
+            | RouterRunnable(runnables)
+        )
+        super().__init__(bound=router, kwargs={}, functions=functions)
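A quick usage sketch of the new router (the function schema is trimmed from the unit test added later in this commit; ChatOpenAI and an API key are assumptions):

    from langchain.chat_models import ChatOpenAI
    from langchain.runnables.openai_functions import OpenAIFunctionsRouter

    router = OpenAIFunctionsRouter(
        {
            # plain callables receive the parsed function arguments as a dict
            "accept": lambda args: f"Accepted: {args['draft']}",
        },
        functions=[
            {
                "name": "accept",
                "description": "Accepts the draft.",
                "parameters": {
                    "type": "object",
                    "properties": {"draft": {"type": "string"}},
                },
            }
        ],
    )

    model = ChatOpenAI()
    chain = model.bind(functions=router.functions) | router
    # chain.invoke("Draft something about turtles?") parses the model's
    # function call and routes it to the matching runnable.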
@@ -107,9 +107,11 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
     def invoke(
         self, input: str, config: Optional[RunnableConfig] = None
     ) -> List[Document]:
-        _config: Dict[str, Any] = dict(config or {})
-        _config.pop("_locals", None)
-        return self.get_relevant_documents(input, **_config)
+        config = config or {}
+        config_kwargs: Dict = {
+            k: config.get(k) for k in ("callbacks", "tags", "metadata")
+        }
+        return self.get_relevant_documents(input, **config_kwargs)
 
     async def ainvoke(
         self, input: str, config: Optional[RunnableConfig] = None
@@ -118,9 +120,11 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
             # If the retriever doesn't implement async, use default implementation
             return await super().ainvoke(input, config)
 
-        _config: Dict[str, Any] = dict(config or {})
-        _config.pop("_locals", None)
-        return await self.aget_relevant_documents(input, **_config)
+        config = config or {}
+        config_kwargs: Dict = {
+            k: config.get(k) for k in ("callbacks", "tags", "metadata")
+        }
+        return await self.aget_relevant_documents(input, **config_kwargs)
 
     @abstractmethod
     def _get_relevant_documents(
@@ -1229,7 +1229,7 @@ class RunnablePassthrough(Serializable, Runnable[Input, Input]):
 
 class RunnableBinding(Serializable, Runnable[Input, Output]):
     """
-    A runnable that binds a runnable to a set of kwargs.
+    A runnable that delegates calls to another runnable with a set of kwargs.
     """
 
     bound: Runnable[Input, Output]
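The reworded docstring matches how a RunnableBinding is produced in practice: Runnable.bind returns one (as the test later in this commit does with model.bind(functions=...)). A small sketch (the model and the stop kwarg are illustrative):

    from langchain.chat_models import ChatOpenAI

    model = ChatOpenAI()
    # `bound` is a RunnableBinding; every invoke on it delegates to `model`
    # with stop=["\n"] merged into that call's kwargs.
    bound = model.bind(stop=["\n"])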
@@ -1314,8 +1314,15 @@ class RouterRunnable(
 
     runnables: Mapping[str, Runnable[Input, Output]]
 
-    def __init__(self, runnables: Mapping[str, Runnable[Input, Output]]) -> None:
-        super().__init__(runnables=runnables)
+    def __init__(
+        self,
+        runnables: Mapping[
+            str, Union[Runnable[Input, Output], Callable[[Input], Output]]
+        ],
+    ) -> None:
+        super().__init__(
+            runnables={key: _coerce_to_runnable(r) for key, r in runnables.items()}
+        )
 
     class Config:
         arbitrary_types_allowed = True
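With the coercion in place, plain callables can be registered directly and RouterRunnable wraps each one. A minimal sketch, relying on the {"key": ..., "input": ...} input shape that OpenAIFunctionsRouter pipes into the router above:

    from langchain.schema.runnable import RouterRunnable

    router = RouterRunnable({
        "double": lambda x: x * 2,  # callables are coerced to runnables
        "negate": lambda x: -x,
    })
    assert router.invoke({"key": "double", "input": 4}) == 8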
@@ -502,6 +502,18 @@ def _construct_run_evaluator(
     return run_evaluator
 
 
+def _get_keys(
+    config: RunEvalConfig,
+    run_inputs: Optional[List[str]],
+    run_outputs: Optional[List[str]],
+    example_outputs: Optional[List[str]],
+) -> Tuple[Optional[str], Optional[str], Optional[str]]:
+    input_key = _determine_input_key(config, run_inputs)
+    prediction_key = _determine_prediction_key(config, run_outputs)
+    reference_key = _determine_reference_key(config, example_outputs)
+    return input_key, prediction_key, reference_key
+
+
 def _load_run_evaluators(
     config: RunEvalConfig,
     run_type: str,
@@ -521,9 +533,13 @@ def _load_run_evaluators(
     """
     eval_llm = config.eval_llm or ChatOpenAI(model="gpt-4", temperature=0.0)
     run_evaluators = []
-    input_key = _determine_input_key(config, run_inputs)
-    prediction_key = _determine_prediction_key(config, run_outputs)
-    reference_key = _determine_reference_key(config, example_outputs)
+    input_key, prediction_key, reference_key = None, None, None
+    if config.evaluators or any(
+        [isinstance(e, EvaluatorType) for e in config.evaluators]
+    ):
+        input_key, prediction_key, reference_key = _get_keys(
+            config, run_inputs, run_outputs, example_outputs
+        )
     for eval_config in config.evaluators:
         run_evaluator = _construct_run_evaluator(
             eval_config,
@@ -1078,15 +1094,15 @@ def _run_on_examples(
         A dictionary mapping example ids to the model outputs.
     """
     results: Dict[str, Any] = {}
-    llm_or_chain_factory = _wrap_in_chain_factory(llm_or_chain_factory)
-    project_name = _get_project_name(project_name, llm_or_chain_factory)
+    wrapped_model = _wrap_in_chain_factory(llm_or_chain_factory)
+    project_name = _get_project_name(project_name, wrapped_model)
     tracer = LangChainTracer(
         project_name=project_name, client=client, use_threading=False
     )
     run_evaluators, examples = _setup_evaluation(
-        llm_or_chain_factory, examples, evaluation, data_type
+        wrapped_model, examples, evaluation, data_type
     )
-    examples = _validate_example_inputs(examples, llm_or_chain_factory, input_mapper)
+    examples = _validate_example_inputs(examples, wrapped_model, input_mapper)
     evalution_handler = EvaluatorCallbackHandler(
         evaluators=run_evaluators or [],
         client=client,
|
|||||||
for i, example in enumerate(examples):
|
for i, example in enumerate(examples):
|
||||||
result = _run_llm_or_chain(
|
result = _run_llm_or_chain(
|
||||||
example,
|
example,
|
||||||
llm_or_chain_factory,
|
wrapped_model,
|
||||||
num_repetitions,
|
num_repetitions,
|
||||||
tags=tags,
|
tags=tags,
|
||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
@@ -1118,8 +1134,8 @@ def _prepare_eval_run(
     llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
     project_name: Optional[str],
 ) -> Tuple[MCF, str, Dataset, Iterator[Example]]:
-    llm_or_chain_factory = _wrap_in_chain_factory(llm_or_chain_factory, dataset_name)
-    project_name = _get_project_name(project_name, llm_or_chain_factory)
+    wrapped_model = _wrap_in_chain_factory(llm_or_chain_factory, dataset_name)
+    project_name = _get_project_name(project_name, wrapped_model)
     try:
         project = client.create_project(project_name)
     except ValueError as e:
|
|||||||
)
|
)
|
||||||
dataset = client.read_dataset(dataset_name=dataset_name)
|
dataset = client.read_dataset(dataset_name=dataset_name)
|
||||||
examples = client.list_examples(dataset_id=str(dataset.id))
|
examples = client.list_examples(dataset_id=str(dataset.id))
|
||||||
return llm_or_chain_factory, project_name, dataset, examples
|
return wrapped_model, project_name, dataset, examples
|
||||||
|
|
||||||
|
|
||||||
async def arun_on_dataset(
|
async def arun_on_dataset(
|
||||||
@@ -1260,13 +1276,13 @@ async def arun_on_dataset(
         evaluation=evaluation_config,
     )
     """  # noqa: E501
-    llm_or_chain_factory, project_name, dataset, examples = _prepare_eval_run(
+    wrapped_model, project_name, dataset, examples = _prepare_eval_run(
         client, dataset_name, llm_or_chain_factory, project_name
     )
     results = await _arun_on_examples(
         client,
         examples,
-        llm_or_chain_factory,
+        wrapped_model,
         concurrency_level=concurrency_level,
         num_repetitions=num_repetitions,
         project_name=project_name,
@@ -1427,14 +1443,14 @@ def run_on_dataset(
         evaluation=evaluation_config,
     )
     """  # noqa: E501
-    llm_or_chain_factory, project_name, dataset, examples = _prepare_eval_run(
+    wrapped_model, project_name, dataset, examples = _prepare_eval_run(
         client, dataset_name, llm_or_chain_factory, project_name
     )
     if concurrency_level in (0, 1):
         results = _run_on_examples(
             client,
             examples,
-            llm_or_chain_factory,
+            wrapped_model,
             num_repetitions=num_repetitions,
             project_name=project_name,
             verbose=verbose,
|
|||||||
coro = _arun_on_examples(
|
coro = _arun_on_examples(
|
||||||
client,
|
client,
|
||||||
examples,
|
examples,
|
||||||
llm_or_chain_factory,
|
wrapped_model,
|
||||||
concurrency_level=concurrency_level,
|
concurrency_level=concurrency_level,
|
||||||
num_repetitions=num_repetitions,
|
num_repetitions=num_repetitions,
|
||||||
project_name=project_name,
|
project_name=project_name,
|
||||||
|
@@ -203,7 +203,13 @@ class BaseTool(BaseModel, Runnable[Union[str, Dict], Any], metaclass=ToolMetacla
         **kwargs: Any,
     ) -> Any:
         config = config or {}
-        return self.run(input, **config, **kwargs)
+        return self.run(
+            input,
+            callbacks=config.get("callbacks"),
+            tags=config.get("tags"),
+            metadata=config.get("metadata"),
+            **kwargs,
+        )
 
     async def ainvoke(
         self,
@@ -216,7 +222,13 @@ class BaseTool(BaseModel, Runnable[Union[str, Dict], Any], metaclass=ToolMetacla
             return super().ainvoke(input, config, **kwargs)
 
         config = config or {}
-        return await self.arun(input, **config, **kwargs)
+        return await self.arun(
+            input,
+            callbacks=config.get("callbacks"),
+            tags=config.get("tags"),
+            metadata=config.get("metadata"),
+            **kwargs,
+        )
 
     # --- Tool ---
 
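Tools get the same treatment: invoke/ainvoke forward only the tracing-related keys to run/arun. A usage sketch with a made-up tool:

    from langchain.tools import tool

    @tool
    def word_count(text: str) -> int:
        """Count the words in a piece of text."""
        return len(text.split())

    # callbacks/tags/metadata from the config reach run(); other config
    # keys are ignored rather than passed through as kwargs.
    word_count.invoke("one two three", config={"tags": ["demo"]})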
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "langchain"
-version = "0.0.259"
+version = "0.0.260"
 description = "Building applications with LLMs through composability"
 authors = []
 license = "MIT"
@@ -0,0 +1,9 @@
+"""Test the airbyte document loader.
+
+Light test to ensure that the airbyte document loader can be imported.
+"""
+
+
+def test_airbyte_import() -> None:
+    """Test that the airbyte document loader can be imported."""
+    from langchain.document_loaders import airbyte  # noqa
@@ -0,0 +1,31 @@
+# serializer version: 1
+# name: test_openai_functions_router
+  list([
+    dict({
+      'description': 'Sends the draft for revision.',
+      'name': 'revise',
+      'parameters': dict({
+        'properties': dict({
+          'notes': dict({
+            'description': "The editor's notes to guide the revision.",
+            'type': 'string',
+          }),
+        }),
+        'type': 'object',
+      }),
+    }),
+    dict({
+      'description': 'Accepts the draft.',
+      'name': 'accept',
+      'parameters': dict({
+        'properties': dict({
+          'draft': dict({
+            'description': 'The draft to accept.',
+            'type': 'string',
+          }),
+        }),
+        'type': 'object',
+      }),
+    }),
+  ])
+# ---
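The block above is a syrupy snapshot file (.ambr): the `snapshot` fixture in the test below compares `router.functions` against it, and the file is (re)generated by running pytest with syrupy's update flag, e.g.:

    pytest tests/unit_tests/runnables --snapshot-update  # path illustrative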
@@ -0,0 +1,95 @@
+from typing import Any, List, Optional
+
+from pytest_mock import MockerFixture
+from syrupy import SnapshotAssertion
+
+from langchain.callbacks.manager import CallbackManagerForLLMRun
+from langchain.chat_models.base import BaseChatModel
+from langchain.runnables.openai_functions import OpenAIFunctionsRouter
+from langchain.schema import ChatResult
+from langchain.schema.messages import AIMessage, BaseMessage
+from langchain.schema.output import ChatGeneration
+
+
+class FakeChatOpenAI(BaseChatModel):
+    @property
+    def _llm_type(self) -> str:
+        return "fake-openai-chat-model"
+
+    def _generate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        return ChatResult(
+            generations=[
+                ChatGeneration(
+                    message=AIMessage(
+                        content="",
+                        additional_kwargs={
+                            "function_call": {
+                                "name": "accept",
+                                "arguments": '{\n "draft": "turtles"\n}',
+                            }
+                        },
+                    )
+                )
+            ]
+        )
+
+
+def test_openai_functions_router(
+    snapshot: SnapshotAssertion, mocker: MockerFixture
+) -> None:
+    revise = mocker.Mock(
+        side_effect=lambda kw: f'Revised draft: no more {kw["notes"]}!'
+    )
+    accept = mocker.Mock(side_effect=lambda kw: f'Accepted draft: {kw["draft"]}!')
+
+    router = OpenAIFunctionsRouter(
+        {
+            "revise": revise,
+            "accept": accept,
+        },
+        functions=[
+            {
+                "name": "revise",
+                "description": "Sends the draft for revision.",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "notes": {
+                            "type": "string",
+                            "description": "The editor's notes to guide the revision.",
+                        },
+                    },
+                },
+            },
+            {
+                "name": "accept",
+                "description": "Accepts the draft.",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "draft": {
+                            "type": "string",
+                            "description": "The draft to accept.",
+                        },
+                    },
+                },
+            },
+        ],
+    )
+
+    model = FakeChatOpenAI()
+
+    chain = model.bind(functions=router.functions) | router
+
+    assert router.functions == snapshot
+
+    assert chain.invoke("Something about turtles?") == "Accepted draft: turtles!"
+
+    revise.assert_not_called()
+    accept.assert_called_once_with({"draft": "turtles"})