mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-17 16:39:52 +00:00
docs, standard-tests: how to standard test a custom tool, imports (#27931)
This commit is contained in:
parent
39fcb476fd
commit
409c7946ac
223
docs/docs/how_to/tools_standard_tests.ipynb
Normal file
223
docs/docs/how_to/tools_standard_tests.ipynb
Normal file
@ -0,0 +1,223 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# How to add standard tests to a tool\n",
|
||||||
|
"\n",
|
||||||
|
"When creating either a custom tool or a new tool to publish in a LangChain integration, it is important to add standard tests to ensure the tool works as expected. This guide will show you how to add standard tests to a tool.\n",
|
||||||
|
"\n",
|
||||||
|
"## Setup\n",
|
||||||
|
"\n",
|
||||||
|
"First, let's install 2 dependencies:\n",
|
||||||
|
"\n",
|
||||||
|
"- `langchain-core` will define the interfaces we want to import to define our custom tool.\n",
|
||||||
|
"- `langchain-tests==0.3.0` will provide the standard tests we want to use.\n",
|
||||||
|
"\n",
|
||||||
|
":::note\n",
|
||||||
|
"\n",
|
||||||
|
"The `langchain-tests` package contains the module `langchain_standard_tests`. This name\n",
|
||||||
|
"mistmatch is due to this package historically being called `langchain_standard_tests` and\n",
|
||||||
|
"the name not being available on PyPi. This will either be reconciled by our \n",
|
||||||
|
"[PEP 541 request](https://github.com/pypi/support/issues/5062) (we welcome upvotes!), \n",
|
||||||
|
"or in a new release of `langchain-tests`.\n",
|
||||||
|
"\n",
|
||||||
|
"Because added tests in new versions of `langchain-tests` will always break your CI/CD pipelines, we recommend pinning the \n",
|
||||||
|
"version of `langchain-tests==0.3.0` to avoid unexpected changes.\n",
|
||||||
|
"\n",
|
||||||
|
":::"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"%pip install -U langchain-core langchain-tests==0.3.0 pytest pytest-socket"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Let's say we're publishing a package, `langchain_parrot_link`, that exposes a\n",
|
||||||
|
"tool called `ParrotMultiplyTool`:"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# title=\"langchain_parrot_link/tools.py\"\n",
|
||||||
|
"from langchain_core.tools import BaseTool\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"class ParrotMultiplyTool(BaseTool):\n",
|
||||||
|
" name: str = \"ParrotMultiplyTool\"\n",
|
||||||
|
" description: str = (\n",
|
||||||
|
" \"Multiply two numbers like a parrot. Parrots always add \"\n",
|
||||||
|
" \"eighty for their matey.\"\n",
|
||||||
|
" )\n",
|
||||||
|
"\n",
|
||||||
|
" def _run(self, a: int, b: int) -> int:\n",
|
||||||
|
" return a * b + 80"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"And we'll assume you've structured your package the same way as the main LangChain\n",
|
||||||
|
"packages:\n",
|
||||||
|
"\n",
|
||||||
|
"```\n",
|
||||||
|
"/\n",
|
||||||
|
"├── langchain_parrot_link/\n",
|
||||||
|
"│ └── tools.py\n",
|
||||||
|
"└── tests/\n",
|
||||||
|
" ├── unit_tests/\n",
|
||||||
|
" │ └── test_tools.py\n",
|
||||||
|
" └── integration_tests/\n",
|
||||||
|
" └── test_tools.py\n",
|
||||||
|
"```\n",
|
||||||
|
"\n",
|
||||||
|
"## Add and configure standard tests\n",
|
||||||
|
"\n",
|
||||||
|
"There are 2 namespaces in the `langchain-tests` package: \n",
|
||||||
|
"\n",
|
||||||
|
"- unit tests (`langchain_standard_tests.unit_tests`): designed to be used to test the tool in isolation and without access to external services\n",
|
||||||
|
"- integration tests (`langchain_standard_tests.integration_tests`): designed to be used to test the tool with access to external services (in particular, the external service that the tool is designed to interact with).\n",
|
||||||
|
"\n",
|
||||||
|
":::note\n",
|
||||||
|
"\n",
|
||||||
|
"Integration tests can also be run without access to external services, **if** they are properly mocked.\n",
|
||||||
|
"\n",
|
||||||
|
":::\n",
|
||||||
|
"\n",
|
||||||
|
"Both types of tests are implemented as [`pytest` class-based test suites](https://docs.pytest.org/en/7.1.x/getting-started.html#group-multiple-tests-in-a-class).\n",
|
||||||
|
"\n",
|
||||||
|
"By subclassing the base classes for each type of standard test (see below), you get all of the standard tests for that type, and you\n",
|
||||||
|
"can override the properties that the test suite uses to configure the tests.\n",
|
||||||
|
"\n",
|
||||||
|
"### Standard tools tests\n",
|
||||||
|
"\n",
|
||||||
|
"Here's how you would configure the standard unit tests for the custom tool, e.g. in `tests/test_tools.py`:"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"title": "tests/test_custom_tool.py"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# title=\"tests/unit_tests/test_custom_tool.py\"\n",
|
||||||
|
"from typing import Type\n",
|
||||||
|
"\n",
|
||||||
|
"from langchain_parrot_link.tools import ParrotMultiplyTool\n",
|
||||||
|
"from langchain_standard_tests.unit_tests import ToolsUnitTests\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"class MultiplyToolUnitTests(ToolsUnitTests):\n",
|
||||||
|
" @property\n",
|
||||||
|
" def tool_constructor(self) -> Type[ParrotMultiplyTool]:\n",
|
||||||
|
" return ParrotMultiplyTool\n",
|
||||||
|
"\n",
|
||||||
|
" def tool_constructor_params(self) -> dict:\n",
|
||||||
|
" # if your tool constructor instead required initialization arguments like\n",
|
||||||
|
" # `def __init__(self, some_arg: int):`, you would return those here\n",
|
||||||
|
" # as a dictionary, e.g.: `return {'some_arg': 42}`\n",
|
||||||
|
" return {}\n",
|
||||||
|
"\n",
|
||||||
|
" def tool_invoke_params_example(self) -> dict:\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" Returns a dictionary representing the \"args\" of an example tool call.\n",
|
||||||
|
"\n",
|
||||||
|
" This should NOT be a ToolCall dict - i.e. it should not\n",
|
||||||
|
" have {\"name\", \"id\", \"args\"} keys.\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" return {\"a\": 2, \"b\": 3}"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# title=\"tests/integration_tests/test_custom_tool.py\"\n",
|
||||||
|
"from typing import Type\n",
|
||||||
|
"\n",
|
||||||
|
"from langchain_parrot_link.tools import ParrotMultiplyTool\n",
|
||||||
|
"from langchain_standard_tests.integration_tests import ToolsIntegrationTests\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"class MultiplyToolIntegrationTests(ToolsIntegrationTests):\n",
|
||||||
|
" @property\n",
|
||||||
|
" def tool_constructor(self) -> Type[ParrotMultiplyTool]:\n",
|
||||||
|
" return ParrotMultiplyTool\n",
|
||||||
|
"\n",
|
||||||
|
" def tool_constructor_params(self) -> dict:\n",
|
||||||
|
" # if your tool constructor instead required initialization arguments like\n",
|
||||||
|
" # `def __init__(self, some_arg: int):`, you would return those here\n",
|
||||||
|
" # as a dictionary, e.g.: `return {'some_arg': 42}`\n",
|
||||||
|
" return {}\n",
|
||||||
|
"\n",
|
||||||
|
" def tool_invoke_params_example(self) -> dict:\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" Returns a dictionary representing the \"args\" of an example tool call.\n",
|
||||||
|
"\n",
|
||||||
|
" This should NOT be a ToolCall dict - i.e. it should not\n",
|
||||||
|
" have {\"name\", \"id\", \"args\"} keys.\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" return {\"a\": 2, \"b\": 3}"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"and you would run these with the following commands from your project root\n",
|
||||||
|
"\n",
|
||||||
|
"```bash\n",
|
||||||
|
"# run unit tests without network access\n",
|
||||||
|
"pytest --disable-socket --enable-unix-socket tests/unit_tests\n",
|
||||||
|
"\n",
|
||||||
|
"# run integration tests\n",
|
||||||
|
"pytest tests/integration_tests\n",
|
||||||
|
"```"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": ".venv",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.11.4"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 2
|
||||||
|
}
|
@ -24,6 +24,16 @@ class EscapePreprocessor(Preprocessor):
|
|||||||
# escape ``` in code
|
# escape ``` in code
|
||||||
cell.source = cell.source.replace("```", r"\`\`\`")
|
cell.source = cell.source.replace("```", r"\`\`\`")
|
||||||
# escape ``` in output
|
# escape ``` in output
|
||||||
|
|
||||||
|
# allow overriding title based on comment at beginning of cell
|
||||||
|
if cell.source.startswith("# title="):
|
||||||
|
lines = cell.source.split("\n")
|
||||||
|
title = lines[0].split("# title=")[1]
|
||||||
|
if title.startswith('"') and title.endswith('"'):
|
||||||
|
title = title[1:-1]
|
||||||
|
cell.metadata["title"] = title
|
||||||
|
cell.source = "\n".join(lines[1:])
|
||||||
|
|
||||||
if "outputs" in cell:
|
if "outputs" in cell:
|
||||||
filter_out = set()
|
filter_out = set()
|
||||||
for i, output in enumerate(cell["outputs"]):
|
for i, output in enumerate(cell["outputs"]):
|
||||||
|
@ -1,5 +1,20 @@
|
|||||||
{% extends 'markdown/index.md.j2' %}
|
{% extends 'markdown/index.md.j2' %}
|
||||||
|
|
||||||
|
{% block input %}
|
||||||
|
```
|
||||||
|
{%- if 'magics_language' in cell.metadata -%}
|
||||||
|
{{ cell.metadata.magics_language}}
|
||||||
|
{%- elif 'name' in nb.metadata.get('language_info', {}) -%}
|
||||||
|
{{ nb.metadata.language_info.name }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- if 'title' in cell.metadata -%}
|
||||||
|
{{ ' ' }}title="{{ cell.metadata.title }}"
|
||||||
|
|
||||||
|
{%- endif %}
|
||||||
|
{{ cell.source}}
|
||||||
|
```
|
||||||
|
{% endblock input %}
|
||||||
|
|
||||||
{%- block traceback_line -%}
|
{%- block traceback_line -%}
|
||||||
```output
|
```output
|
||||||
{{ line.rstrip() | strip_ansi }}
|
{{ line.rstrip() | strip_ansi }}
|
||||||
|
@ -10,6 +10,7 @@ modules = [
|
|||||||
"chat_models",
|
"chat_models",
|
||||||
"vectorstores",
|
"vectorstores",
|
||||||
"embeddings",
|
"embeddings",
|
||||||
|
"tools",
|
||||||
]
|
]
|
||||||
|
|
||||||
for module in modules:
|
for module in modules:
|
||||||
@ -17,14 +18,21 @@ for module in modules:
|
|||||||
f"langchain_standard_tests.integration_tests.{module}"
|
f"langchain_standard_tests.integration_tests.{module}"
|
||||||
)
|
)
|
||||||
|
|
||||||
from langchain_standard_tests.integration_tests.chat_models import (
|
from .base_store import BaseStoreAsyncTests, BaseStoreSyncTests
|
||||||
ChatModelIntegrationTests,
|
from .cache import AsyncCacheTestSuite, SyncCacheTestSuite
|
||||||
)
|
from .chat_models import ChatModelIntegrationTests
|
||||||
from langchain_standard_tests.integration_tests.embeddings import (
|
from .embeddings import EmbeddingsIntegrationTests
|
||||||
EmbeddingsIntegrationTests,
|
from .tools import ToolsIntegrationTests
|
||||||
)
|
from .vectorstores import AsyncReadWriteTestSuite, ReadWriteTestSuite
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ChatModelIntegrationTests",
|
"ChatModelIntegrationTests",
|
||||||
"EmbeddingsIntegrationTests",
|
"EmbeddingsIntegrationTests",
|
||||||
|
"ToolsIntegrationTests",
|
||||||
|
"BaseStoreAsyncTests",
|
||||||
|
"BaseStoreSyncTests",
|
||||||
|
"AsyncCacheTestSuite",
|
||||||
|
"SyncCacheTestSuite",
|
||||||
|
"AsyncReadWriteTestSuite",
|
||||||
|
"ReadWriteTestSuite",
|
||||||
]
|
]
|
||||||
|
@ -7,11 +7,14 @@ import pytest
|
|||||||
modules = [
|
modules = [
|
||||||
"chat_models",
|
"chat_models",
|
||||||
"embeddings",
|
"embeddings",
|
||||||
|
"tools",
|
||||||
]
|
]
|
||||||
|
|
||||||
for module in modules:
|
for module in modules:
|
||||||
pytest.register_assert_rewrite(f"langchain_standard_tests.unit_tests.{module}")
|
pytest.register_assert_rewrite(f"langchain_standard_tests.unit_tests.{module}")
|
||||||
|
|
||||||
from langchain_standard_tests.unit_tests.chat_models import ChatModelUnitTests
|
from .chat_models import ChatModelUnitTests
|
||||||
|
from .embeddings import EmbeddingsUnitTests
|
||||||
|
from .tools import ToolsUnitTests
|
||||||
|
|
||||||
__all__ = ["ChatModelUnitTests", "EmbeddingsUnitTests"]
|
__all__ = ["ChatModelUnitTests", "EmbeddingsUnitTests", "ToolsUnitTests"]
|
||||||
|
Loading…
Reference in New Issue
Block a user