community: move to separate repo (#31060)

langchain-community is moving to
https://github.com/langchain-ai/langchain-community
ccurme
2025-04-29 09:22:04 -04:00
committed by GitHub
parent 7e926520d5
commit 9ff5b5d282
2349 changed files with 30 additions and 372332 deletions

View File

@@ -1,8 +1,8 @@
Thank you for contributing to LangChain!
- [ ] **PR title**: "package: description"
- Where "package" is whichever of langchain, community, core, etc. is being modified. Use "docs: ..." for purely docs changes, "infra: ..." for CI changes.
- Example: "community: add foobar LLM"
- Where "package" is whichever of langchain, core, etc. is being modified. Use "docs: ..." for purely docs changes, "infra: ..." for CI changes.
- Example: "core: add foobar LLM"
- [ ] **PR message**: ***Delete this entire checklist*** and replace with
@@ -24,6 +24,5 @@ Additional guidelines:
- Please do not add dependencies to pyproject.toml files (even optional ones) unless they are required for unit tests.
- Most PRs should not touch more than one package.
- Changes should be backwards compatible.
- If you are adding something to community, do not re-import it in langchain.
If no one reviews your PR within a few days, please @-mention one of baskaryan, eyurtsev, ccurme, vbarda, hwchase17.

View File

@@ -16,7 +16,6 @@ LANGCHAIN_DIRS = [
"libs/core",
"libs/text-splitters",
"libs/langchain",
"libs/community",
]
# when set to True, we are ignoring core dependents
@@ -134,12 +133,6 @@ def _get_configs_for_single_dir(job: str, dir_: str) -> List[Dict[str, str]]:
elif dir_ == "libs/langchain" and job == "extended-tests":
py_versions = ["3.9", "3.13"]
elif dir_ == "libs/community" and job == "extended-tests":
py_versions = ["3.9", "3.12"]
elif dir_ == "libs/community" and job == "compile-integration-tests":
# community integration deps are slow in 3.12
py_versions = ["3.9", "3.11"]
elif dir_ == ".":
# unable to install with 3.13 because tokenizers doesn't support 3.13 yet
py_versions = ["3.9", "3.12"]
@@ -184,11 +177,6 @@ def _get_pydantic_test_configs(
else "0"
)
custom_mins = {
# depends on pydantic-settings 2.4 which requires pydantic 2.7
"libs/community": 7,
}
max_pydantic_minor = min(
int(dir_max_pydantic_minor),
int(core_max_pydantic_minor),
@@ -196,7 +184,6 @@ def _get_pydantic_test_configs(
min_pydantic_minor = max(
int(dir_min_pydantic_minor),
int(core_min_pydantic_minor),
custom_mins.get(dir_, 0),
)
configs = [
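
With the community-specific pydantic floor removed, the surviving logic is just an interval intersection between the directory's supported minors and core's; a hedged sketch under assumed names:

```python
def pydantic_minor_range(dir_min: int, core_min: int,
                         dir_max: int, core_max: int) -> range:
    # Test every pydantic minor version that both the directory
    # under test and langchain-core support.
    lo = max(dir_min, core_min)
    hi = min(dir_max, core_max)
    return range(lo, hi + 1)

print(list(pydantic_minor_range(4, 5, 9, 10)))  # -> [5, 6, 7, 8, 9]
```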

View File

@@ -22,7 +22,6 @@ import re
MIN_VERSION_LIBS = [
"langchain-core",
"langchain-community",
"langchain",
"langchain-text-splitters",
"numpy",
@@ -35,7 +34,6 @@ SKIP_IF_PULL_REQUEST = [
"langchain-core",
"langchain-text-splitters",
"langchain",
"langchain-community",
]

View File

@@ -20,6 +20,8 @@ def get_target_dir(package_name: str) -> Path:
base_path = Path("langchain/libs")
if package_name_short == "experimental":
return base_path / "experimental"
if package_name_short == "community":
return base_path / "community"
return base_path / "partners" / package_name_short

View File

@@ -1,4 +1,3 @@
libs/community/langchain_community/llms/yuan2.py
"NotIn": "not in",
- `/checkin`: Check-in
docs/docs/integrations/providers/trulens.mdx

View File

@@ -34,11 +34,6 @@ jobs:
shell: bash
run: uv sync --group test --group test_integration
- name: Install deps outside pyproject
if: ${{ startsWith(inputs.working-directory, 'libs/community/') }}
shell: bash
run: VIRTUAL_ENV=.venv uv pip install "boto3<2" "google-cloud-aiplatform<2"
- name: Run integration tests
shell: bash
env:

View File

@@ -30,7 +30,7 @@ jobs:
- name: Install langchain editable
run: |
VIRTUAL_ENV=.venv uv pip install langchain-experimental -e libs/core libs/langchain libs/community
VIRTUAL_ENV=.venv uv pip install langchain-experimental langchain-community -e libs/core libs/langchain
- name: Check doc imports
shell: bash

View File

@@ -7,12 +7,6 @@ repos:
entry: make -C libs/core format
files: ^libs/core/
pass_filenames: false
- id: community
name: format community
language: system
entry: make -C libs/community format
files: ^libs/community/
pass_filenames: false
- id: langchain
name: format langchain
language: system

View File

@@ -13,23 +13,33 @@ Install `uv`: **[documentation on how to install it](https://docs.astral.sh/uv/g
This repository contains multiple packages:
- `langchain-core`: Base interfaces for key abstractions as well as logic for combining them in chains (LangChain Expression Language).
- `langchain-community`: Third-party integrations of various components.
- `langchain`: Chains, agents, and retrieval logic that makes up the cognitive architecture of your applications.
- `langchain-experimental`: Components and chains that are experimental, either in the sense that the techniques are novel and still being tested, or they require giving the LLM more access than would be possible in most production systems.
- Partner integrations: Partner packages in `libs/partners` that are independently version controlled.
:::note
Some LangChain packages live outside the monorepo, see for example
[langchain-community](https://github.com/langchain-ai/langchain-community) for various
third-party integrations and
[langchain-experimental](https://github.com/langchain-ai/langchain-experimental) for
abstractions that are experimental (either in the sense that the techniques are novel
and still being tested, or they require giving the LLM more access than would be
possible in most production systems).
:::
Each of these has its own development environment. Docs are run from the top-level makefile, but development
is split across separate test & release flows.
For this quickstart, start with langchain-community:
For this quickstart, start with `langchain`:
```bash
cd libs/community
cd libs/langchain
```
## Local Development Dependencies
Install langchain-community development requirements (for running langchain, running examples, linting, formatting, tests, and coverage):
Install development requirements (for running langchain, running examples, linting, formatting, tests, and coverage):
```bash
uv sync
@@ -62,22 +72,15 @@ make docker_tests
There are also [integration tests and code-coverage](../testing.mdx) available.
### Only develop langchain_core or langchain_community
### Developing langchain_core
If you are only developing `langchain_core` or `langchain_community`, you can simply install the dependencies for the respective projects and run tests:
If you are only developing `langchain_core`, you can simply install the dependencies for the project and run tests:
```bash
cd libs/core
make test
```
Or:
```bash
cd libs/community
make test
```
## Formatting and Linting
Run these locally before submitting a PR; the CI system will also check.

View File

@@ -1,79 +0,0 @@
.PHONY: all format lint test tests test_watch integration_tests docker_tests help extended_tests
# Default target executed when no arguments are given to make.
all: help
# Define a variable for the test file path.
TEST_FILE ?= tests/unit_tests/
integration_tests: TEST_FILE = tests/integration_tests/
.EXPORT_ALL_VARIABLES:
UV_FROZEN = true
# Run unit tests and generate a coverage report.
coverage:
uv run --group test pytest --cov \
--cov-config=.coveragerc \
--cov-report xml \
--cov-report term-missing:skip-covered \
$(TEST_FILE)
test tests:
uv run --group test pytest -n auto --disable-socket --allow-unix-socket $(TEST_FILE)
integration_tests:
uv run --group test --group test_integration pytest $(TEST_FILE)
test_watch:
uv run --group test ptw --disable-socket --allow-unix-socket --snapshot-update --now . -- -vv tests/unit_tests
check_imports: $(shell find langchain_community -name '*.py')
uv run --group test python ./scripts/check_imports.py $^
extended_tests:
uv run --no-sync --group test pytest --disable-socket --allow-unix-socket --only-extended tests/unit_tests
######################
# LINTING AND FORMATTING
######################
# Define a variable for Python and notebook files.
PYTHON_FILES=.
MYPY_CACHE=.mypy_cache
lint format: PYTHON_FILES=.
lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/community --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
lint_package: PYTHON_FILES=langchain_community
lint_tests: PYTHON_FILES=tests
lint_tests: MYPY_CACHE=.mypy_cache_test
lint lint_diff lint_package lint_tests:
./scripts/check_pydantic.sh .
./scripts/lint_imports.sh .
./scripts/check_pickle.sh .
[ "$(PYTHON_FILES)" = "" ] || uv run --group typing --group lint ruff check $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || uv run --group typing --group lint ruff format $(PYTHON_FILES) --diff
[ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) && uv run --group typing --group lint mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
format format_diff:
[ "$(PYTHON_FILES)" = "" ] || uv run --group typing --group lint ruff format $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || uv run --group typing --group lint ruff check --select I --fix $(PYTHON_FILES)
spell_check:
uv run --group typing --group lint codespell --toml pyproject.toml
spell_fix:
uv run --group typing --group lint codespell --toml pyproject.toml -w
######################
# HELP
######################
help:
@echo '----'
@echo 'format - run code formatters'
@echo 'lint - run linters'
@echo 'test - run unit tests'
@echo 'tests - run unit tests'
@echo 'test TEST_FILE=<test_file> - run all tests in file'
@echo 'test_watch - run unit tests in watch mode'

View File

@@ -1,30 +1,3 @@
# 🦜️🧑‍🤝‍🧑 LangChain Community
This package has moved!
[![Downloads](https://static.pepy.tech/badge/langchain_community/month)](https://pepy.tech/project/langchain_community)
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
## Quick Install
```bash
pip install langchain-community
```
## What is it?
LangChain Community contains third-party integrations that implement the base interfaces defined in LangChain Core, making them ready-to-use in any LangChain application.
For full documentation see the [API reference](https://python.langchain.com/api_reference/community/index.html).
![Diagram outlining the hierarchical organization of the LangChain framework, displaying the interconnected parts across multiple layers.](https://raw.githubusercontent.com/langchain-ai/langchain/master/docs/static/svg/langchain_stack_112024.svg "LangChain Framework Overview")
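To make the layering concrete, a minimal runnable sketch: a community class (here `FakeListLLM`, which needs no API key) implements the `langchain-core` LLM interface, so it drops into any LangChain chain:

```python
from langchain_community.llms.fake import FakeListLLM

# FakeListLLM implements the core Runnable/LLM interface, so standard
# methods like invoke() work out of the box.
llm = FakeListLLM(responses=["ready-to-use in any LangChain application"])
print(llm.invoke("demo"))
```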
## 📕 Releases & Versioning
`langchain-community` is currently on version `0.0.x`.
All changes will be accompanied by a patch version increase.
## 💁 Contributing
As an open-source project in a rapidly developing field, we are extremely open to contributions, whether it be in the form of a new feature, improved infrastructure, or better documentation.
For detailed information on how to contribute, see the [Contributing Guide](https://python.langchain.com/docs/contributing/).
https://github.com/langchain-ai/langchain-community/tree/main/libs/community

View File

@@ -1,107 +0,0 @@
aiosqlite>=0.19.0,<0.20
aleph-alpha-client>=2.15.0,<3
anthropic>=0.3.11,<0.4
arxiv>=1.4,<2
assemblyai>=0.17.0,<0.18
atlassian-python-api>=3.36.0,<4
azure-ai-documentintelligence>=1.0.0b1,<2
azure-identity>=1.15.0,<2
azure-search-documents==11.4.0
azure.ai.vision.imageanalysis>=1.0.0,<2
beautifulsoup4>=4,<5
bibtexparser>=1.4.0,<2
cassio>=0.1.6,<0.2
chardet>=5.1.0,<6
cloudpathlib>=0.18,<0.19
cloudpickle>=2.0.0
cohere>=4,<6
databricks-vectorsearch>=0.21,<0.22
datasets>=2.15.0,<3
dgml-utils>=0.3.0,<0.4
elasticsearch>=8.12.0,<9
esprima>=4.0.1,<5
faiss-cpu>=1,<2
feedparser>=6.0.10,<7
fireworks-ai>=0.9.0,<0.10
friendli-client>=1.2.4,<2
geopandas>=0.13.1
gitpython>=3.1.32,<4
gliner>=0.2.7
google-cloud-documentai>=2.20.1,<3
gql>=3.4.1,<4
gradientai>=1.4.0,<2
graphviz>=0.20.3,<0.21
hdbcli>=2.19.21,<3
hologres-vector==0.0.6
html2text>=2020.1.16
httpx>=0.24.1,<0.25
httpx-sse>=0.4.0,<0.5
jinja2>=3,<4
jq>=1.4.1,<2
jsonschema>1
keybert>=0.8.5
langchain_openai>=0.2.1
litellm>=1.30,<=1.39.5
lxml>=4.9.3,<6.0
markdownify>=0.11.6,<0.12
motor>=3.3.1,<4
msal>=1.25.0,<2
mwparserfromhell>=0.6.4,<0.7
mwxml>=0.3.3,<0.4
needle-python>=0.4
networkx>=3.2.1,<4
newspaper3k>=0.2.8,<0.3
numexpr>=2.8.6,<3
nvidia-riva-client>=2.14.0,<3
oci>=2.128.0,<3
openai<2
openapi-pydantic>=0.3.2,<0.4
oracle-ads>=2.9.1,<3
oracledb>=2.2.0,<3
pandas>=2.0.1,<3
pdfminer-six==20231228
pdfplumber>=0.11
pgvector>=0.1.6,<0.2
playwright>=1.48.0,<2
praw>=7.7.1,<8
premai>=0.3.25,<0.4,!=0.3.100
psychicapi>=0.8.0,<0.9
pydantic>=2.7.4,<3
pytesseract>=0.3.13
py-trello>=0.19.0,<0.20
pyjwt>=2.8.0,<3
pymupdf>=1.22.3,<2
pypdf>=3.4.0,<5
pypdfium2>=4.10.0,<5
rank-bm25>=0.2.2,<0.3
rapidfuzz>=3.1.1,<4
rapidocr-onnxruntime>=1.3.2,<2
rdflib==7.0.0
requests-toolbelt>=1.0.0,<2
rspace_client>=2.5.0,<3
scikit-learn>=1.2.2,<2
simsimd>=5.0.0,<6
sqlite-vss>=0.1.2,<0.2
sqlite-vec>=0.1.0,<0.2
sseclient-py>=1.8.0,<2
streamlit>=1.18.0,<2
sympy>=1.12,<2
telethon>=1.28.5,<2
tidb-vector>=0.0.3,<1.0.0
timescale-vector==0.0.1
tqdm>=4.48.0
tiktoken>=0.8.0
tree-sitter>=0.20.2,<0.21
tree-sitter-languages>=1.8.0,<2
upstash-redis>=1.1.0,<2
upstash-ratelimit>=1.1.0,<2
vdms>=0.0.20
xata>=1.0.0a7,<2
xmltodict>=0.13.0,<0.14
nanopq==0.2.1
mlflow[genai]>=2.14.0
databricks-sdk>=0.30.0
websocket>=0.2.1,<1
writer-sdk>=1.2.0
yandexcloud==0.144.0
unstructured[pdf]>=0.15

View File

@@ -1,10 +0,0 @@
"""Main entrypoint into package."""
from importlib import metadata
try:
__version__ = metadata.version(__package__)
except metadata.PackageNotFoundError:
# Case where package metadata is not available.
__version__ = ""
del metadata # optional, avoids polluting the results of dir(__package__)
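
This is the standard `importlib.metadata` idiom for resolving a package's installed version; a standalone sketch using an assumed distribution name:

```python
from importlib import metadata

try:
    version = metadata.version("langchain-community")  # distribution name, assumed installed
except metadata.PackageNotFoundError:
    # Running from a source tree without installed package metadata.
    version = ""
print(version or "<unknown>")
```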

View File

@@ -1,8 +0,0 @@
"""**Adapters** are used to adapt LangChain models to other APIs.
LangChain integrates with many model providers.
While LangChain has its own message and model APIs, it also makes it as easy as
possible to explore other models by exposing an **adapter** that adapts LangChain
models to other APIs, such as the OpenAI API.
"""

View File

@@ -1,421 +0,0 @@
from __future__ import annotations
import importlib
from typing import (
Any,
AsyncIterator,
Dict,
Iterable,
List,
Mapping,
Sequence,
Union,
overload,
)
from langchain_core.chat_sessions import ChatSession
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
ChatMessage,
FunctionMessage,
HumanMessage,
SystemMessage,
ToolMessage,
)
from pydantic import BaseModel
from typing_extensions import Literal
async def aenumerate(
iterable: AsyncIterator[Any], start: int = 0
) -> AsyncIterator[tuple[int, Any]]:
"""Async version of enumerate function."""
i = start
async for x in iterable:
yield i, x
i += 1
class IndexableBaseModel(BaseModel):
"""Allows a BaseModel to return its fields by string variable indexing."""
def __getitem__(self, item: str) -> Any:
return getattr(self, item)
class Choice(IndexableBaseModel):
"""Choice."""
message: dict
class ChatCompletions(IndexableBaseModel):
"""Chat completions."""
choices: List[Choice]
class ChoiceChunk(IndexableBaseModel):
"""Choice chunk."""
delta: dict
class ChatCompletionChunk(IndexableBaseModel):
"""Chat completion chunk."""
choices: List[ChoiceChunk]
def convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
"""Convert a dictionary to a LangChain message.
Args:
_dict: The dictionary.
Returns:
The LangChain message.
"""
role = _dict.get("role")
if role == "user":
return HumanMessage(content=_dict.get("content", ""))
elif role == "assistant":
# Fix for azure
# Also OpenAI returns None for tool invocations
content = _dict.get("content", "") or ""
additional_kwargs: Dict = {}
if function_call := _dict.get("function_call"):
additional_kwargs["function_call"] = dict(function_call)
if tool_calls := _dict.get("tool_calls"):
additional_kwargs["tool_calls"] = tool_calls
if context := _dict.get("context"):
additional_kwargs["context"] = context
return AIMessage(content=content, additional_kwargs=additional_kwargs)
elif role == "system":
return SystemMessage(content=_dict.get("content", ""))
elif role == "function":
return FunctionMessage(content=_dict.get("content", ""), name=_dict.get("name")) # type: ignore[arg-type]
elif role == "tool":
additional_kwargs = {}
if "name" in _dict:
additional_kwargs["name"] = _dict["name"]
return ToolMessage(
content=_dict.get("content", ""),
tool_call_id=_dict.get("tool_call_id"),
additional_kwargs=additional_kwargs,
)
else:
return ChatMessage(content=_dict.get("content", ""), role=role) # type: ignore[arg-type]
def convert_message_to_dict(message: BaseMessage) -> dict:
"""Convert a LangChain message to a dictionary.
Args:
message: The LangChain message.
Returns:
The dictionary.
"""
message_dict: Dict[str, Any]
if isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content}
elif isinstance(message, HumanMessage):
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": message.content}
if "function_call" in message.additional_kwargs:
message_dict["function_call"] = message.additional_kwargs["function_call"]
# If function call only, content is None not empty string
if message_dict["content"] == "":
message_dict["content"] = None
if "tool_calls" in message.additional_kwargs:
message_dict["tool_calls"] = message.additional_kwargs["tool_calls"]
# If tool calls only, content is None not empty string
if message_dict["content"] == "":
message_dict["content"] = None
if "context" in message.additional_kwargs:
message_dict["context"] = message.additional_kwargs["context"]
# If context only, content is None not empty string
if message_dict["content"] == "":
message_dict["content"] = None
elif isinstance(message, SystemMessage):
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, FunctionMessage):
message_dict = {
"role": "function",
"content": message.content,
"name": message.name,
}
elif isinstance(message, ToolMessage):
message_dict = {
"role": "tool",
"content": message.content,
"tool_call_id": message.tool_call_id,
}
else:
raise TypeError(f"Got unknown type {message}")
if "name" in message.additional_kwargs:
message_dict["name"] = message.additional_kwargs["name"]
return message_dict
def convert_openai_messages(messages: Sequence[Dict[str, Any]]) -> List[BaseMessage]:
"""Convert dictionaries representing OpenAI messages to LangChain format.
Args:
messages: List of dictionaries representing OpenAI messages
Returns:
List of LangChain BaseMessage objects.
"""
return [convert_dict_to_message(m) for m in messages]
def _convert_message_chunk(chunk: BaseMessageChunk, i: int) -> dict:
_dict: Dict[str, Any] = {}
if isinstance(chunk, AIMessageChunk):
if i == 0:
# Only shows up in the first chunk
_dict["role"] = "assistant"
if "function_call" in chunk.additional_kwargs:
_dict["function_call"] = chunk.additional_kwargs["function_call"]
# If the first chunk is a function call, the content is not empty string,
# not missing, but None.
if i == 0:
_dict["content"] = None
if "tool_calls" in chunk.additional_kwargs:
_dict["tool_calls"] = chunk.additional_kwargs["tool_calls"]
# If the first chunk is tool calls, the content is not empty string,
# not missing, but None.
if i == 0:
_dict["content"] = None
else:
_dict["content"] = chunk.content
else:
raise ValueError(f"Got unexpected streaming chunk type: {type(chunk)}")
# This only happens at the end of streams, and OpenAI returns as empty dict
if _dict == {"content": ""}:
_dict = {}
return _dict
def _convert_message_chunk_to_delta(chunk: BaseMessageChunk, i: int) -> Dict[str, Any]:
_dict = _convert_message_chunk(chunk, i)
return {"choices": [{"delta": _dict}]}
class ChatCompletion:
"""Chat completion."""
@overload
@staticmethod
def create(
messages: Sequence[Dict[str, Any]],
*,
provider: str = "ChatOpenAI",
stream: Literal[False] = False,
**kwargs: Any,
) -> dict: ...
@overload
@staticmethod
def create(
messages: Sequence[Dict[str, Any]],
*,
provider: str = "ChatOpenAI",
stream: Literal[True],
**kwargs: Any,
) -> Iterable: ...
@staticmethod
def create(
messages: Sequence[Dict[str, Any]],
*,
provider: str = "ChatOpenAI",
stream: bool = False,
**kwargs: Any,
) -> Union[dict, Iterable]:
models = importlib.import_module("langchain.chat_models")
model_cls = getattr(models, provider)
model_config = model_cls(**kwargs)
converted_messages = convert_openai_messages(messages)
if not stream:
result = model_config.invoke(converted_messages)
return {"choices": [{"message": convert_message_to_dict(result)}]}
else:
return (
_convert_message_chunk_to_delta(c, i)
for i, c in enumerate(model_config.stream(converted_messages))
)
@overload
@staticmethod
async def acreate(
messages: Sequence[Dict[str, Any]],
*,
provider: str = "ChatOpenAI",
stream: Literal[False] = False,
**kwargs: Any,
) -> dict: ...
@overload
@staticmethod
async def acreate(
messages: Sequence[Dict[str, Any]],
*,
provider: str = "ChatOpenAI",
stream: Literal[True],
**kwargs: Any,
) -> AsyncIterator: ...
@staticmethod
async def acreate(
messages: Sequence[Dict[str, Any]],
*,
provider: str = "ChatOpenAI",
stream: bool = False,
**kwargs: Any,
) -> Union[dict, AsyncIterator]:
models = importlib.import_module("langchain.chat_models")
model_cls = getattr(models, provider)
model_config = model_cls(**kwargs)
converted_messages = convert_openai_messages(messages)
if not stream:
result = await model_config.ainvoke(converted_messages)
return {"choices": [{"message": convert_message_to_dict(result)}]}
else:
return (
_convert_message_chunk_to_delta(c, i)
async for i, c in aenumerate(model_config.astream(converted_messages))
)
def _has_assistant_message(session: ChatSession) -> bool:
"""Check if chat session has an assistant message."""
return any(isinstance(m, AIMessage) for m in session["messages"])
def convert_messages_for_finetuning(
sessions: Iterable[ChatSession],
) -> List[List[dict]]:
"""Convert messages to a list of lists of dictionaries for fine-tuning.
Args:
sessions: The chat sessions.
Returns:
The list of lists of dictionaries.
"""
return [
[convert_message_to_dict(s) for s in session["messages"]]
for session in sessions
if _has_assistant_message(session)
]
class Completions:
"""Completions."""
@overload
@staticmethod
def create(
messages: Sequence[Dict[str, Any]],
*,
provider: str = "ChatOpenAI",
stream: Literal[False] = False,
**kwargs: Any,
) -> ChatCompletions: ...
@overload
@staticmethod
def create(
messages: Sequence[Dict[str, Any]],
*,
provider: str = "ChatOpenAI",
stream: Literal[True],
**kwargs: Any,
) -> Iterable: ...
@staticmethod
def create(
messages: Sequence[Dict[str, Any]],
*,
provider: str = "ChatOpenAI",
stream: bool = False,
**kwargs: Any,
) -> Union[ChatCompletions, Iterable]:
models = importlib.import_module("langchain.chat_models")
model_cls = getattr(models, provider)
model_config = model_cls(**kwargs)
converted_messages = convert_openai_messages(messages)
if not stream:
result = model_config.invoke(converted_messages)
return ChatCompletions(
choices=[Choice(message=convert_message_to_dict(result))]
)
else:
return (
ChatCompletionChunk(
choices=[ChoiceChunk(delta=_convert_message_chunk(c, i))]
)
for i, c in enumerate(model_config.stream(converted_messages))
)
@overload
@staticmethod
async def acreate(
messages: Sequence[Dict[str, Any]],
*,
provider: str = "ChatOpenAI",
stream: Literal[False] = False,
**kwargs: Any,
) -> ChatCompletions: ...
@overload
@staticmethod
async def acreate(
messages: Sequence[Dict[str, Any]],
*,
provider: str = "ChatOpenAI",
stream: Literal[True],
**kwargs: Any,
) -> AsyncIterator: ...
@staticmethod
async def acreate(
messages: Sequence[Dict[str, Any]],
*,
provider: str = "ChatOpenAI",
stream: bool = False,
**kwargs: Any,
) -> Union[ChatCompletions, AsyncIterator]:
models = importlib.import_module("langchain.chat_models")
model_cls = getattr(models, provider)
model_config = model_cls(**kwargs)
converted_messages = convert_openai_messages(messages)
if not stream:
result = await model_config.ainvoke(converted_messages)
return ChatCompletions(
choices=[Choice(message=convert_message_to_dict(result))]
)
else:
return (
ChatCompletionChunk(
choices=[ChoiceChunk(delta=_convert_message_chunk(c, i))]
)
async for i, c in aenumerate(model_config.astream(converted_messages))
)
class Chat:
"""Chat."""
def __init__(self) -> None:
self.completions = Completions()
chat = Chat()
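
A hedged usage sketch for the adapter above: it mirrors the legacy `openai.ChatCompletion` calling convention on top of a LangChain chat model. It assumes `langchain` is installed with an importable `ChatOpenAI` and that `OPENAI_API_KEY` is set; the model name is illustrative.

```python
from langchain_community.adapters.openai import ChatCompletion

# Kwargs after `provider` are forwarded to the chat model's constructor.
result = ChatCompletion.create(
    messages=[{"role": "user", "content": "Say hi"}],
    provider="ChatOpenAI",
    model="gpt-4o-mini",  # illustrative model name
)
print(result["choices"][0]["message"]["content"])
```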

View File

@@ -1,170 +0,0 @@
"""**Toolkits** are sets of tools that can be used to interact with
various services and APIs.
"""
import importlib
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from langchain_community.agent_toolkits.ainetwork.toolkit import (
AINetworkToolkit,
)
from langchain_community.agent_toolkits.amadeus.toolkit import (
AmadeusToolkit,
)
from langchain_community.agent_toolkits.azure_ai_services import (
AzureAiServicesToolkit,
)
from langchain_community.agent_toolkits.azure_cognitive_services import (
AzureCognitiveServicesToolkit,
)
from langchain_community.agent_toolkits.cassandra_database.toolkit import (
CassandraDatabaseToolkit, # noqa: F401
)
from langchain_community.agent_toolkits.cogniswitch.toolkit import (
CogniswitchToolkit,
)
from langchain_community.agent_toolkits.connery import (
ConneryToolkit,
)
from langchain_community.agent_toolkits.file_management.toolkit import (
FileManagementToolkit,
)
from langchain_community.agent_toolkits.gmail.toolkit import (
GmailToolkit,
)
from langchain_community.agent_toolkits.jira.toolkit import (
JiraToolkit,
)
from langchain_community.agent_toolkits.json.base import (
create_json_agent,
)
from langchain_community.agent_toolkits.json.toolkit import (
JsonToolkit,
)
from langchain_community.agent_toolkits.multion.toolkit import (
MultionToolkit,
)
from langchain_community.agent_toolkits.nasa.toolkit import (
NasaToolkit,
)
from langchain_community.agent_toolkits.nla.toolkit import (
NLAToolkit,
)
from langchain_community.agent_toolkits.office365.toolkit import (
O365Toolkit,
)
from langchain_community.agent_toolkits.openapi.base import (
create_openapi_agent,
)
from langchain_community.agent_toolkits.openapi.toolkit import (
OpenAPIToolkit,
)
from langchain_community.agent_toolkits.playwright.toolkit import (
PlayWrightBrowserToolkit,
)
from langchain_community.agent_toolkits.polygon.toolkit import (
PolygonToolkit,
)
from langchain_community.agent_toolkits.powerbi.base import (
create_pbi_agent,
)
from langchain_community.agent_toolkits.powerbi.chat_base import (
create_pbi_chat_agent,
)
from langchain_community.agent_toolkits.powerbi.toolkit import (
PowerBIToolkit,
)
from langchain_community.agent_toolkits.slack.toolkit import (
SlackToolkit,
)
from langchain_community.agent_toolkits.spark_sql.base import (
create_spark_sql_agent,
)
from langchain_community.agent_toolkits.spark_sql.toolkit import (
SparkSQLToolkit,
)
from langchain_community.agent_toolkits.sql.base import (
create_sql_agent,
)
from langchain_community.agent_toolkits.sql.toolkit import (
SQLDatabaseToolkit,
)
from langchain_community.agent_toolkits.steam.toolkit import (
SteamToolkit,
)
from langchain_community.agent_toolkits.zapier.toolkit import (
ZapierToolkit,
)
__all__ = [
"AINetworkToolkit",
"AmadeusToolkit",
"AzureAiServicesToolkit",
"AzureCognitiveServicesToolkit",
"CogniswitchToolkit",
"ConneryToolkit",
"FileManagementToolkit",
"GmailToolkit",
"JiraToolkit",
"JsonToolkit",
"MultionToolkit",
"NLAToolkit",
"NasaToolkit",
"O365Toolkit",
"OpenAPIToolkit",
"PlayWrightBrowserToolkit",
"PolygonToolkit",
"PowerBIToolkit",
"SQLDatabaseToolkit",
"SlackToolkit",
"SparkSQLToolkit",
"SteamToolkit",
"ZapierToolkit",
"create_json_agent",
"create_openapi_agent",
"create_pbi_agent",
"create_pbi_chat_agent",
"create_spark_sql_agent",
"create_sql_agent",
]
_module_lookup = {
"AINetworkToolkit": "langchain_community.agent_toolkits.ainetwork.toolkit",
"AmadeusToolkit": "langchain_community.agent_toolkits.amadeus.toolkit",
"AzureAiServicesToolkit": "langchain_community.agent_toolkits.azure_ai_services",
"AzureCognitiveServicesToolkit": "langchain_community.agent_toolkits.azure_cognitive_services", # noqa: E501
"CogniswitchToolkit": "langchain_community.agent_toolkits.cogniswitch.toolkit",
"ConneryToolkit": "langchain_community.agent_toolkits.connery",
"FileManagementToolkit": "langchain_community.agent_toolkits.file_management.toolkit", # noqa: E501
"GmailToolkit": "langchain_community.agent_toolkits.gmail.toolkit",
"JiraToolkit": "langchain_community.agent_toolkits.jira.toolkit",
"JsonToolkit": "langchain_community.agent_toolkits.json.toolkit",
"MultionToolkit": "langchain_community.agent_toolkits.multion.toolkit",
"NLAToolkit": "langchain_community.agent_toolkits.nla.toolkit",
"NasaToolkit": "langchain_community.agent_toolkits.nasa.toolkit",
"O365Toolkit": "langchain_community.agent_toolkits.office365.toolkit",
"OpenAPIToolkit": "langchain_community.agent_toolkits.openapi.toolkit",
"PlayWrightBrowserToolkit": "langchain_community.agent_toolkits.playwright.toolkit",
"PolygonToolkit": "langchain_community.agent_toolkits.polygon.toolkit",
"PowerBIToolkit": "langchain_community.agent_toolkits.powerbi.toolkit",
"SQLDatabaseToolkit": "langchain_community.agent_toolkits.sql.toolkit",
"SlackToolkit": "langchain_community.agent_toolkits.slack.toolkit",
"SparkSQLToolkit": "langchain_community.agent_toolkits.spark_sql.toolkit",
"SteamToolkit": "langchain_community.agent_toolkits.steam.toolkit",
"ZapierToolkit": "langchain_community.agent_toolkits.zapier.toolkit",
"create_json_agent": "langchain_community.agent_toolkits.json.base",
"create_openapi_agent": "langchain_community.agent_toolkits.openapi.base",
"create_pbi_agent": "langchain_community.agent_toolkits.powerbi.base",
"create_pbi_chat_agent": "langchain_community.agent_toolkits.powerbi.chat_base",
"create_spark_sql_agent": "langchain_community.agent_toolkits.spark_sql.base",
"create_sql_agent": "langchain_community.agent_toolkits.sql.base",
}
def __getattr__(name: str) -> Any:
if name in _module_lookup:
module = importlib.import_module(_module_lookup[name])
return getattr(module, name)
raise AttributeError(f"module {__name__} has no attribute {name}")
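
The `__getattr__` hook above is the module-level lazy-import pattern from PEP 562: heavy submodules are imported only on first attribute access. A self-contained sketch of the same pattern, with an illustrative lookup table:

```python
import importlib
from typing import Any

# Maps exported attribute -> module that actually defines it (illustrative).
_module_lookup = {"sqrt": "math", "dataclass": "dataclasses"}

def __getattr__(name: str) -> Any:
    # Only invoked when normal module attribute lookup fails (PEP 562),
    # so the defining module is imported lazily, on first use.
    if name in _module_lookup:
        module = importlib.import_module(_module_lookup[name])
        return getattr(module, name)
    raise AttributeError(f"module {__name__} has no attribute {name}")
```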

View File

@@ -1 +0,0 @@
"""AINetwork toolkit."""

View File

@@ -1,70 +0,0 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, List, Literal, Optional
from langchain_core.tools import BaseTool
from langchain_core.tools.base import BaseToolkit
from pydantic import ConfigDict, model_validator
from langchain_community.tools.ainetwork.app import AINAppOps
from langchain_community.tools.ainetwork.owner import AINOwnerOps
from langchain_community.tools.ainetwork.rule import AINRuleOps
from langchain_community.tools.ainetwork.transfer import AINTransfer
from langchain_community.tools.ainetwork.utils import authenticate
from langchain_community.tools.ainetwork.value import AINValueOps
if TYPE_CHECKING:
from ain.ain import Ain
class AINetworkToolkit(BaseToolkit):
"""Toolkit for interacting with AINetwork Blockchain.
*Security Note*: This toolkit contains tools that can read and modify
the state of a service; e.g., by reading, creating, updating, deleting
data associated with this service.
See https://python.langchain.com/docs/security for more information.
Parameters:
network: Optional. The network to connect to. Default is "testnet".
Options are "mainnet" or "testnet".
interface: Optional. The interface to use. If not provided, will
attempt to authenticate with the network. Default is None.
"""
network: Optional[Literal["mainnet", "testnet"]] = "testnet"
interface: Optional[Ain] = None
@model_validator(mode="before")
@classmethod
def set_interface(cls, values: dict) -> Any:
"""Set the interface if not provided.
If the interface is not provided, attempt to authenticate with the
network using the network value provided.
Args:
values: The values to validate.
Returns:
The validated values.
"""
if not values.get("interface"):
values["interface"] = authenticate(network=values.get("network", "testnet"))
return values
model_config = ConfigDict(
arbitrary_types_allowed=True,
validate_default=True,
)
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
return [
AINAppOps(),
AINOwnerOps(),
AINRuleOps(),
AINTransfer(),
AINValueOps(),
]
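
A hedged instantiation sketch: omitting `interface` triggers the `set_interface` validator's `authenticate()` call, which (per that helper's requirements) assumes the `ain-py` SDK is installed and an AIN private key is available in the environment:

```python
from langchain_community.agent_toolkits.ainetwork.toolkit import AINetworkToolkit

# Omitting `interface` makes the validator authenticate against the network.
toolkit = AINetworkToolkit(network="testnet")
print([tool.name for tool in toolkit.get_tools()])
```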

View File

@@ -1,38 +0,0 @@
from __future__ import annotations
from typing import TYPE_CHECKING, List, Optional
from langchain_core.language_models import BaseLanguageModel
from langchain_core.tools import BaseTool
from langchain_core.tools.base import BaseToolkit
from pydantic import ConfigDict, Field
from langchain_community.tools.amadeus.closest_airport import AmadeusClosestAirport
from langchain_community.tools.amadeus.flight_search import AmadeusFlightSearch
from langchain_community.tools.amadeus.utils import authenticate
if TYPE_CHECKING:
from amadeus import Client
class AmadeusToolkit(BaseToolkit):
"""Toolkit for interacting with Amadeus which offers APIs for travel.
Parameters:
client: Optional. The Amadeus client. Default is None.
llm: Optional. The language model to use. Default is None.
"""
client: Client = Field(default_factory=authenticate)
llm: Optional[BaseLanguageModel] = Field(default=None)
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
return [
AmadeusClosestAirport(llm=self.llm),
AmadeusFlightSearch(),
]

View File

@@ -1,31 +0,0 @@
from __future__ import annotations
from typing import List
from langchain_core.tools import BaseTool
from langchain_core.tools.base import BaseToolkit
from langchain_community.tools.azure_ai_services import (
AzureAiServicesDocumentIntelligenceTool,
AzureAiServicesImageAnalysisTool,
AzureAiServicesSpeechToTextTool,
AzureAiServicesTextAnalyticsForHealthTool,
AzureAiServicesTextToSpeechTool,
)
class AzureAiServicesToolkit(BaseToolkit):
"""Toolkit for Azure AI Services."""
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
tools: List[BaseTool] = [
AzureAiServicesDocumentIntelligenceTool(), # type: ignore[call-arg]
AzureAiServicesImageAnalysisTool(),
AzureAiServicesSpeechToTextTool(), # type: ignore[call-arg]
AzureAiServicesTextToSpeechTool(), # type: ignore[call-arg]
AzureAiServicesTextAnalyticsForHealthTool(), # type: ignore[call-arg]
]
return tools

View File

@@ -1,34 +0,0 @@
from __future__ import annotations
import sys
from typing import List
from langchain_core.tools import BaseTool
from langchain_core.tools.base import BaseToolkit
from langchain_community.tools.azure_cognitive_services import (
AzureCogsFormRecognizerTool,
AzureCogsImageAnalysisTool,
AzureCogsSpeech2TextTool,
AzureCogsText2SpeechTool,
AzureCogsTextAnalyticsHealthTool,
)
class AzureCognitiveServicesToolkit(BaseToolkit):
"""Toolkit for Azure Cognitive Services."""
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
tools: List[BaseTool] = [
AzureCogsFormRecognizerTool(), # type: ignore[call-arg]
AzureCogsSpeech2TextTool(), # type: ignore[call-arg]
AzureCogsText2SpeechTool(), # type: ignore[call-arg]
AzureCogsTextAnalyticsHealthTool(), # type: ignore[call-arg]
]
# TODO: Remove check once azure-ai-vision supports MacOS.
if sys.platform.startswith("linux") or sys.platform.startswith("win"):
tools.append(AzureCogsImageAnalysisTool()) # type: ignore[call-arg]
return tools

View File

@@ -1,5 +0,0 @@
"""Toolkits for agents."""
from langchain_core.tools.base import BaseToolkit
__all__ = ["BaseToolkit"]

View File

@@ -1 +0,0 @@
"""Apache Cassandra Toolkit."""

View File

@@ -1,37 +0,0 @@
"""Apache Cassandra Toolkit."""
from typing import List
from langchain_core.tools import BaseTool
from langchain_core.tools.base import BaseToolkit
from pydantic import ConfigDict, Field
from langchain_community.tools.cassandra_database.tool import (
GetSchemaCassandraDatabaseTool,
GetTableDataCassandraDatabaseTool,
QueryCassandraDatabaseTool,
)
from langchain_community.utilities.cassandra_database import CassandraDatabase
class CassandraDatabaseToolkit(BaseToolkit):
"""Toolkit for interacting with an Apache Cassandra database.
Parameters:
db: CassandraDatabase. The Cassandra database to interact
with.
"""
db: CassandraDatabase = Field(exclude=True)
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
return [
GetSchemaCassandraDatabaseTool(db=self.db),
QueryCassandraDatabaseTool(db=self.db),
GetTableDataCassandraDatabaseTool(db=self.db),
]

View File

@@ -1,120 +0,0 @@
from typing import Dict, List
from langchain_core.tools import BaseTool
from langchain_core.tools.base import BaseToolkit
from langchain_community.tools.clickup.prompt import (
CLICKUP_FOLDER_CREATE_PROMPT,
CLICKUP_GET_ALL_TEAMS_PROMPT,
CLICKUP_GET_FOLDERS_PROMPT,
CLICKUP_GET_LIST_PROMPT,
CLICKUP_GET_SPACES_PROMPT,
CLICKUP_GET_TASK_ATTRIBUTE_PROMPT,
CLICKUP_GET_TASK_PROMPT,
CLICKUP_LIST_CREATE_PROMPT,
CLICKUP_TASK_CREATE_PROMPT,
CLICKUP_UPDATE_TASK_ASSIGNEE_PROMPT,
CLICKUP_UPDATE_TASK_PROMPT,
)
from langchain_community.tools.clickup.tool import ClickupAction
from langchain_community.utilities.clickup import ClickupAPIWrapper
class ClickupToolkit(BaseToolkit):
"""Clickup Toolkit.
*Security Note*: This toolkit contains tools that can read and modify
the state of a service; e.g., by reading, creating, updating, deleting
data associated with this service.
See https://python.langchain.com/docs/security for more information.
Parameters:
tools: List[BaseTool]. The tools in the toolkit. Default is an empty list.
"""
tools: List[BaseTool] = []
@classmethod
def from_clickup_api_wrapper(
cls, clickup_api_wrapper: ClickupAPIWrapper
) -> "ClickupToolkit":
"""Create a ClickupToolkit from a ClickupAPIWrapper.
Args:
clickup_api_wrapper: ClickupAPIWrapper. The Clickup API wrapper.
Returns:
ClickupToolkit. The Clickup toolkit.
"""
operations: List[Dict] = [
{
"mode": "get_task",
"name": "Get task",
"description": CLICKUP_GET_TASK_PROMPT,
},
{
"mode": "get_task_attribute",
"name": "Get task attribute",
"description": CLICKUP_GET_TASK_ATTRIBUTE_PROMPT,
},
{
"mode": "get_teams",
"name": "Get Teams",
"description": CLICKUP_GET_ALL_TEAMS_PROMPT,
},
{
"mode": "create_task",
"name": "Create Task",
"description": CLICKUP_TASK_CREATE_PROMPT,
},
{
"mode": "create_list",
"name": "Create List",
"description": CLICKUP_LIST_CREATE_PROMPT,
},
{
"mode": "create_folder",
"name": "Create Folder",
"description": CLICKUP_FOLDER_CREATE_PROMPT,
},
{
"mode": "get_list",
"name": "Get all lists in the space",
"description": CLICKUP_GET_LIST_PROMPT,
},
{
"mode": "get_folders",
"name": "Get all folders in the workspace",
"description": CLICKUP_GET_FOLDERS_PROMPT,
},
{
"mode": "get_spaces",
"name": "Get all spaces in the workspace",
"description": CLICKUP_GET_SPACES_PROMPT,
},
{
"mode": "update_task",
"name": "Update task",
"description": CLICKUP_UPDATE_TASK_PROMPT,
},
{
"mode": "update_task_assignees",
"name": "Update task assignees",
"description": CLICKUP_UPDATE_TASK_ASSIGNEE_PROMPT,
},
]
tools = [
ClickupAction(
name=action["name"],
description=action["description"],
mode=action["mode"],
api_wrapper=clickup_api_wrapper,
)
for action in operations
]
return cls(tools=tools) # type: ignore[arg-type]
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
return self.tools
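
A hedged usage sketch for the factory method above; the wrapper's `access_token` constructor argument is assumed, and a real ClickUp credential is required at runtime:

```python
from langchain_community.agent_toolkits.clickup.toolkit import ClickupToolkit
from langchain_community.utilities.clickup import ClickupAPIWrapper

wrapper = ClickupAPIWrapper(access_token="...")  # placeholder credential (assumed arg)
toolkit = ClickupToolkit.from_clickup_api_wrapper(wrapper)
print([tool.name for tool in toolkit.get_tools()])  # the eleven actions above
```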

View File

@@ -1 +0,0 @@
"""CogniSwitch Toolkit"""

View File

@@ -1,45 +0,0 @@
from typing import List
from langchain_core.tools import BaseTool
from langchain_core.tools.base import BaseToolkit
from langchain_community.tools.cogniswitch.tool import (
CogniswitchKnowledgeRequest,
CogniswitchKnowledgeSourceFile,
CogniswitchKnowledgeSourceURL,
CogniswitchKnowledgeStatus,
)
class CogniswitchToolkit(BaseToolkit):
"""Toolkit for CogniSwitch.
Use the toolkit to get all the tools present in the Cogniswitch and
use them to interact with your knowledge.
Parameters:
cs_token: str. The Cogniswitch token.
OAI_token: str. The OpenAI API token.
apiKey: str. The Cogniswitch OAuth token.
"""
cs_token: str
OAI_token: str
apiKey: str
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
return [
CogniswitchKnowledgeStatus(
cs_token=self.cs_token, OAI_token=self.OAI_token, apiKey=self.apiKey
),
CogniswitchKnowledgeRequest(
cs_token=self.cs_token, OAI_token=self.OAI_token, apiKey=self.apiKey
),
CogniswitchKnowledgeSourceFile(
cs_token=self.cs_token, OAI_token=self.OAI_token, apiKey=self.apiKey
),
CogniswitchKnowledgeSourceURL(
cs_token=self.cs_token, OAI_token=self.OAI_token, apiKey=self.apiKey
),
]

View File

@@ -1,7 +0,0 @@
"""
This module contains the ConneryToolkit.
"""
from .toolkit import ConneryToolkit
__all__ = ["ConneryToolkit"]

View File

@@ -1,60 +0,0 @@
from typing import Any, List
from langchain_core.tools import BaseTool
from langchain_core.tools.base import BaseToolkit
from pydantic import model_validator
from langchain_community.tools.connery import ConneryService
class ConneryToolkit(BaseToolkit):
"""
Toolkit with a list of Connery Actions as tools.
Parameters:
tools (List[BaseTool]): The list of Connery Actions.
"""
tools: List[BaseTool]
def get_tools(self) -> List[BaseTool]:
"""
Returns the list of Connery Actions.
"""
return self.tools
@model_validator(mode="before")
@classmethod
def validate_attributes(cls, values: dict) -> Any:
"""
Validate the attributes of the ConneryToolkit class.
Args:
values (dict): The arguments to validate.
Returns:
dict: The validated arguments.
Raises:
ValueError: If the 'tools' attribute is not set
"""
if not values.get("tools"):
raise ValueError("The attribute 'tools' must be set.")
return values
@classmethod
def create_instance(cls, connery_service: ConneryService) -> "ConneryToolkit":
"""
Creates a Connery Toolkit using a Connery Service.
Parameters:
connery_service (ConneryService): The Connery Service
to get the list of Connery Actions.
Returns:
ConneryToolkit: The Connery Toolkit.
"""
instance = cls(tools=connery_service.list_actions()) # type: ignore[arg-type]
return instance

View File

@@ -1,26 +0,0 @@
from pathlib import Path
from typing import Any
from langchain_core._api.path import as_import_path
def __getattr__(name: str) -> Any:
"""Get attr name."""
if name == "create_csv_agent":
# Get directory of langchain package
HERE = Path(__file__).parents[3]
here = as_import_path(Path(__file__).parent, relative_to=HERE)
old_path = "langchain." + here + "." + name
new_path = "langchain_experimental." + here + "." + name
raise ImportError(
"This agent has been moved to langchain experiment. "
"This agent relies on python REPL tool under the hood, so to use it "
"safely please sandbox the python REPL. "
"Read https://github.com/langchain-ai/langchain/blob/master/SECURITY.md "
"and https://github.com/langchain-ai/langchain/discussions/11680"
"To keep using this code as is, install langchain experimental and "
f"update your import statement from:\n `{old_path}` to `{new_path}`."
)
raise AttributeError(f"{name} does not exist")

View File

@@ -1,7 +0,0 @@
"""Local file management toolkit."""
from langchain_community.agent_toolkits.file_management.toolkit import (
FileManagementToolkit,
)
__all__ = ["FileManagementToolkit"]

View File

@@ -1,88 +0,0 @@
from __future__ import annotations
from typing import Any, Dict, List, Optional, Type
from langchain_core.tools import BaseTool, BaseToolkit
from langchain_core.utils.pydantic import get_fields
from pydantic import model_validator
from langchain_community.tools.file_management.copy import CopyFileTool
from langchain_community.tools.file_management.delete import DeleteFileTool
from langchain_community.tools.file_management.file_search import FileSearchTool
from langchain_community.tools.file_management.list_dir import ListDirectoryTool
from langchain_community.tools.file_management.move import MoveFileTool
from langchain_community.tools.file_management.read import ReadFileTool
from langchain_community.tools.file_management.write import WriteFileTool
_FILE_TOOLS: List[Type[BaseTool]] = [
CopyFileTool,
DeleteFileTool,
FileSearchTool,
MoveFileTool,
ReadFileTool,
WriteFileTool,
ListDirectoryTool,
]
_FILE_TOOLS_MAP: Dict[str, Type[BaseTool]] = {
get_fields(tool_cls)["name"].default: tool_cls for tool_cls in _FILE_TOOLS
}
class FileManagementToolkit(BaseToolkit):
"""Toolkit for interacting with local files.
*Security Notice*: This toolkit provides methods to interact with local files.
If providing this toolkit to an agent on an LLM, ensure you scope
the agent's permissions to only include the necessary permissions
to perform the desired operations.
By **default** the agent will have access to all files within
the root dir and will be able to Copy, Delete, Move, Read, Write
and List files in that directory.
Consider the following:
- Limit access to particular directories using `root_dir`.
- Use filesystem permissions to restrict access and permissions to only
the files and directories required by the agent.
- Limit the tools available to the agent to only the file operations
necessary for the agent's intended use.
- Sandbox the agent by running it in a container.
See https://python.langchain.com/docs/security for more information.
Parameters:
root_dir: Optional. The root directory to perform file operations.
If not provided, file operations are performed relative to the current
working directory.
selected_tools: Optional. The tools to include in the toolkit. If not
provided, all tools are included.
"""
root_dir: Optional[str] = None
"""If specified, all file operations are made relative to root_dir."""
selected_tools: Optional[List[str]] = None
"""If provided, only provide the selected tools. Defaults to all."""
@model_validator(mode="before")
@classmethod
def validate_tools(cls, values: dict) -> Any:
selected_tools = values.get("selected_tools") or []
for tool_name in selected_tools:
if tool_name not in _FILE_TOOLS_MAP:
raise ValueError(
f"File Tool of name {tool_name} not supported."
f" Permitted tools: {list(_FILE_TOOLS_MAP)}"
)
return values
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
allowed_tools = self.selected_tools or _FILE_TOOLS_MAP
tools: List[BaseTool] = []
for tool in allowed_tools:
tool_cls = _FILE_TOOLS_MAP[tool]
tools.append(tool_cls(root_dir=self.root_dir))
return tools
__all__ = ["FileManagementToolkit"]
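
A runnable usage sketch for the toolkit above, scoping the tools to a temporary directory; `get_tools` returns tools in the order given by `selected_tools`:

```python
import tempfile

from langchain_community.agent_toolkits.file_management.toolkit import (
    FileManagementToolkit,
)

with tempfile.TemporaryDirectory() as root:
    toolkit = FileManagementToolkit(
        root_dir=root, selected_tools=["write_file", "read_file"]
    )
    write_file, read_file = toolkit.get_tools()
    write_file.invoke({"file_path": "note.txt", "text": "hello"})
    print(read_file.invoke({"file_path": "note.txt"}))  # -> hello
```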

View File

@@ -1 +0,0 @@
"""financial datasets toolkit."""

View File

@@ -1,44 +0,0 @@
from __future__ import annotations
from typing import List
from langchain_core.tools import BaseTool
from langchain_core.tools.base import BaseToolkit
from pydantic import ConfigDict, Field
from langchain_community.tools.financial_datasets.balance_sheets import BalanceSheets
from langchain_community.tools.financial_datasets.cash_flow_statements import (
CashFlowStatements,
)
from langchain_community.tools.financial_datasets.income_statements import (
IncomeStatements,
)
from langchain_community.utilities.financial_datasets import FinancialDatasetsAPIWrapper
class FinancialDatasetsToolkit(BaseToolkit):
"""Toolkit for interacting with financialdatasets.ai.
Parameters:
api_wrapper: The FinancialDatasets API Wrapper.
"""
api_wrapper: FinancialDatasetsAPIWrapper = Field(
default_factory=FinancialDatasetsAPIWrapper
)
def __init__(self, api_wrapper: FinancialDatasetsAPIWrapper):
super().__init__()
self.api_wrapper = api_wrapper
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
return [
BalanceSheets(api_wrapper=self.api_wrapper),
CashFlowStatements(api_wrapper=self.api_wrapper),
IncomeStatements(api_wrapper=self.api_wrapper),
]

View File

@@ -1 +0,0 @@
"""GitHub Toolkit."""

View File

@@ -1,479 +0,0 @@
"""GitHub Toolkit."""
from typing import Dict, List
from langchain_core.tools import BaseTool
from langchain_core.tools.base import BaseToolkit
from pydantic import BaseModel, Field
from langchain_community.tools.github.prompt import (
COMMENT_ON_ISSUE_PROMPT,
CREATE_BRANCH_PROMPT,
CREATE_FILE_PROMPT,
CREATE_PULL_REQUEST_PROMPT,
CREATE_REVIEW_REQUEST_PROMPT,
DELETE_FILE_PROMPT,
GET_FILES_FROM_DIRECTORY_PROMPT,
GET_ISSUE_PROMPT,
GET_ISSUES_PROMPT,
GET_LATEST_RELEASE_PROMPT,
GET_PR_PROMPT,
GET_RELEASE_PROMPT,
GET_RELEASES_PROMPT,
LIST_BRANCHES_IN_REPO_PROMPT,
LIST_PRS_PROMPT,
LIST_PULL_REQUEST_FILES,
OVERVIEW_EXISTING_FILES_BOT_BRANCH,
OVERVIEW_EXISTING_FILES_IN_MAIN,
READ_FILE_PROMPT,
SEARCH_CODE_PROMPT,
SEARCH_ISSUES_AND_PRS_PROMPT,
SET_ACTIVE_BRANCH_PROMPT,
UPDATE_FILE_PROMPT,
)
from langchain_community.tools.github.tool import GitHubAction
from langchain_community.utilities.github import GitHubAPIWrapper
class NoInput(BaseModel):
"""Schema for operations that do not require any input."""
no_input: str = Field("", description="No input required, e.g. `` (empty string).")
class GetIssue(BaseModel):
"""Schema for operations that require an issue number as input."""
issue_number: int = Field(0, description="Issue number as an integer, e.g. `42`")
class CommentOnIssue(BaseModel):
"""Schema for operations that require a comment as input."""
input: str = Field(..., description="Follow the required formatting.")
class GetPR(BaseModel):
"""Schema for operations that require a PR number as input."""
pr_number: int = Field(0, description="The PR number as an integer, e.g. `12`")
class CreatePR(BaseModel):
"""Schema for operations that require a PR title and body as input."""
formatted_pr: str = Field(..., description="Follow the required formatting.")
class CreateFile(BaseModel):
"""Schema for operations that require a file path and content as input."""
formatted_file: str = Field(..., description="Follow the required formatting.")
class ReadFile(BaseModel):
"""Schema for operations that require a file path as input."""
formatted_filepath: str = Field(
...,
description=(
"The full file path of the file you would like to read where the "
"path must NOT start with a slash, e.g. `some_dir/my_file.py`."
),
)
class UpdateFile(BaseModel):
"""Schema for operations that require a file path and content as input."""
formatted_file_update: str = Field(
..., description="Strictly follow the provided rules."
)
class DeleteFile(BaseModel):
"""Schema for operations that require a file path as input."""
formatted_filepath: str = Field(
...,
description=(
"The full file path of the file you would like to delete"
" where the path must NOT start with a slash, e.g."
" `some_dir/my_file.py`. Only input a string,"
" not the param name."
),
)
class DirectoryPath(BaseModel):
"""Schema for operations that require a directory path as input."""
input: str = Field(
"",
description=(
"The path of the directory, e.g. `some_dir/inner_dir`."
" Only input a string, do not include the parameter name."
),
)
class BranchName(BaseModel):
"""Schema for operations that require a branch name as input."""
branch_name: str = Field(
..., description="The name of the branch, e.g. `my_branch`."
)
class SearchCode(BaseModel):
"""Schema for operations that require a search query as input."""
search_query: str = Field(
...,
description=(
"A keyword-focused natural language search"
"query for code, e.g. `MyFunctionName()`."
),
)
class CreateReviewRequest(BaseModel):
"""Schema for operations that require a username as input."""
username: str = Field(
...,
description="GitHub username of the user being requested, e.g. `my_username`.",
)
class SearchIssuesAndPRs(BaseModel):
"""Schema for operations that require a search query as input."""
search_query: str = Field(
...,
description="Natural language search query, e.g. `My issue title or topic`.",
)
class TagName(BaseModel):
"""Schema for operations that require a tag name as input."""
tag_name: str = Field(
...,
description="The tag name of the release, e.g. `v1.0.0`.",
)
class GitHubToolkit(BaseToolkit):
"""GitHub Toolkit.
*Security Note*: This toolkit contains tools that can read and modify
the state of a service; e.g., by creating, deleting, or updating,
reading underlying data.
For example, this toolkit can be used to create issues, pull requests,
and comments on GitHub.
See [Security](https://python.langchain.com/docs/security) for more information.
Setup:
See detailed installation instructions here:
https://python.langchain.com/docs/integrations/tools/github/#installation
You will need to install ``pygithub`` and set the following environment
variables:
.. code-block:: bash
pip install -U pygithub
export GITHUB_APP_ID="your-app-id"
export GITHUB_APP_PRIVATE_KEY="path-to-private-key"
export GITHUB_REPOSITORY="your-github-repository"
Instantiate:
.. code-block:: python
from langchain_community.agent_toolkits.github.toolkit import GitHubToolkit
from langchain_community.utilities.github import GitHubAPIWrapper
github = GitHubAPIWrapper()
toolkit = GitHubToolkit.from_github_api_wrapper(github)
Tools:
.. code-block:: python
tools = toolkit.get_tools()
for tool in tools:
print(tool.name)
.. code-block:: none
Get Issues
Get Issue
Comment on Issue
List open pull requests (PRs)
Get Pull Request
Overview of files included in PR
Create Pull Request
List Pull Requests' Files
Create File
Read File
Update File
Delete File
Overview of existing files in Main branch
Overview of files in current working branch
List branches in this repository
Set active branch
Create a new branch
Get files from a directory
Search issues and pull requests
Search code
Create review request
Include release tools:
By default, the toolkit does not include release-related tools.
You can include them by setting ``include_release_tools=True`` when
initializing the toolkit:
.. code-block:: python
toolkit = GitHubToolkit.from_github_api_wrapper(
github, include_release_tools=True
)
Setting ``include_release_tools=True`` will include the following tools:
.. code-block:: none
Get latest release
Get releases
Get release
Use within an agent:
.. code-block:: python
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import create_react_agent
# Select example tool
tools = [tool for tool in toolkit.get_tools() if tool.name == "Get Issue"]
assert len(tools) == 1
tools[0].name = "get_issue"
llm = ChatOpenAI(model="gpt-4o-mini")
agent_executor = create_react_agent(llm, tools)
example_query = "What is the title of issue 24888?"
events = agent_executor.stream(
{"messages": [("user", example_query)]},
stream_mode="values",
)
for event in events:
event["messages"][-1].pretty_print()
.. code-block:: none
================================ Human Message =================================
What is the title of issue 24888?
================================== Ai Message ==================================
Tool Calls:
get_issue (call_iSYJVaM7uchfNHOMJoVPQsOi)
Call ID: call_iSYJVaM7uchfNHOMJoVPQsOi
Args:
issue_number: 24888
================================= Tool Message =================================
Name: get_issue
{"number": 24888, "title": "Standardize KV-Store Docs", "body": "..."
================================== Ai Message ==================================
The title of issue 24888 is "Standardize KV-Store Docs".
Parameters:
tools: List[BaseTool]. The tools in the toolkit. Default is an empty list.
""" # noqa: E501
tools: List[BaseTool] = []
@classmethod
def from_github_api_wrapper(
cls, github_api_wrapper: GitHubAPIWrapper, include_release_tools: bool = False
) -> "GitHubToolkit":
"""Create a GitHubToolkit from a GitHubAPIWrapper.
Args:
github_api_wrapper: GitHubAPIWrapper. The GitHub API wrapper.
include_release_tools: bool. Whether to include release-related tools.
Defaults to False.
Returns:
GitHubToolkit. The GitHub toolkit.
"""
operations: List[Dict] = [
{
"mode": "get_issues",
"name": "Get Issues",
"description": GET_ISSUES_PROMPT,
"args_schema": NoInput,
},
{
"mode": "get_issue",
"name": "Get Issue",
"description": GET_ISSUE_PROMPT,
"args_schema": GetIssue,
},
{
"mode": "comment_on_issue",
"name": "Comment on Issue",
"description": COMMENT_ON_ISSUE_PROMPT,
"args_schema": CommentOnIssue,
},
{
"mode": "list_open_pull_requests",
"name": "List open pull requests (PRs)",
"description": LIST_PRS_PROMPT,
"args_schema": NoInput,
},
{
"mode": "get_pull_request",
"name": "Get Pull Request",
"description": GET_PR_PROMPT,
"args_schema": GetPR,
},
{
"mode": "list_pull_request_files",
"name": "Overview of files included in PR",
"description": LIST_PULL_REQUEST_FILES,
"args_schema": GetPR,
},
{
"mode": "create_pull_request",
"name": "Create Pull Request",
"description": CREATE_PULL_REQUEST_PROMPT,
"args_schema": CreatePR,
},
{
"mode": "list_pull_request_files",
"name": "List Pull Requests' Files",
"description": LIST_PULL_REQUEST_FILES,
"args_schema": GetPR,
},
{
"mode": "create_file",
"name": "Create File",
"description": CREATE_FILE_PROMPT,
"args_schema": CreateFile,
},
{
"mode": "read_file",
"name": "Read File",
"description": READ_FILE_PROMPT,
"args_schema": ReadFile,
},
{
"mode": "update_file",
"name": "Update File",
"description": UPDATE_FILE_PROMPT,
"args_schema": UpdateFile,
},
{
"mode": "delete_file",
"name": "Delete File",
"description": DELETE_FILE_PROMPT,
"args_schema": DeleteFile,
},
{
"mode": "list_files_in_main_branch",
"name": "Overview of existing files in Main branch",
"description": OVERVIEW_EXISTING_FILES_IN_MAIN,
"args_schema": NoInput,
},
{
"mode": "list_files_in_bot_branch",
"name": "Overview of files in current working branch",
"description": OVERVIEW_EXISTING_FILES_BOT_BRANCH,
"args_schema": NoInput,
},
{
"mode": "list_branches_in_repo",
"name": "List branches in this repository",
"description": LIST_BRANCHES_IN_REPO_PROMPT,
"args_schema": NoInput,
},
{
"mode": "set_active_branch",
"name": "Set active branch",
"description": SET_ACTIVE_BRANCH_PROMPT,
"args_schema": BranchName,
},
{
"mode": "create_branch",
"name": "Create a new branch",
"description": CREATE_BRANCH_PROMPT,
"args_schema": BranchName,
},
{
"mode": "get_files_from_directory",
"name": "Get files from a directory",
"description": GET_FILES_FROM_DIRECTORY_PROMPT,
"args_schema": DirectoryPath,
},
{
"mode": "search_issues_and_prs",
"name": "Search issues and pull requests",
"description": SEARCH_ISSUES_AND_PRS_PROMPT,
"args_schema": SearchIssuesAndPRs,
},
{
"mode": "search_code",
"name": "Search code",
"description": SEARCH_CODE_PROMPT,
"args_schema": SearchCode,
},
{
"mode": "create_review_request",
"name": "Create review request",
"description": CREATE_REVIEW_REQUEST_PROMPT,
"args_schema": CreateReviewRequest,
},
]
release_operations: List[Dict] = [
{
"mode": "get_latest_release",
"name": "Get latest release",
"description": GET_LATEST_RELEASE_PROMPT,
"args_schema": NoInput,
},
{
"mode": "get_releases",
"name": "Get releases",
"description": GET_RELEASES_PROMPT,
"args_schema": NoInput,
},
{
"mode": "get_release",
"name": "Get release",
"description": GET_RELEASE_PROMPT,
"args_schema": TagName,
},
]
operations = operations + (release_operations if include_release_tools else [])
tools = [
GitHubAction(
name=action["name"],
description=action["description"],
mode=action["mode"],
api_wrapper=github_api_wrapper,
args_schema=action.get("args_schema", None),
)
for action in operations
]
return cls(tools=tools) # type: ignore[arg-type]
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
return self.tools
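# A minimal usage sketch (illustrative, not part of the original module): it
# assumes GITHUB_APP_ID, GITHUB_APP_PRIVATE_KEY, and GITHUB_REPOSITORY are set
# in the environment, as required by GitHubAPIWrapper.
from langchain_community.utilities.github import GitHubAPIWrapper

github = GitHubAPIWrapper()  # credentials are read from the environment
toolkit = GitHubToolkit.from_github_api_wrapper(github, include_release_tools=True)
for tool in toolkit.get_tools():
    print(tool.name)  # the three release tools appear at the end of the list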

View File

@@ -1 +0,0 @@
"""GitLab Toolkit."""

View File

@@ -1,170 +0,0 @@
"""GitLab Toolkit."""
from typing import Dict, List, Optional
from langchain_core.tools import BaseTool
from langchain_core.tools.base import BaseToolkit
from langchain_community.tools.gitlab.prompt import (
COMMENT_ON_ISSUE_PROMPT,
CREATE_FILE_PROMPT,
CREATE_PULL_REQUEST_PROMPT,
CREATE_REPO_BRANCH,
DELETE_FILE_PROMPT,
GET_ISSUE_PROMPT,
GET_ISSUES_PROMPT,
GET_REPO_FILES_FROM_DIRECTORY,
GET_REPO_FILES_IN_BOT_BRANCH,
GET_REPO_FILES_IN_MAIN,
LIST_REPO_BRANCES,
READ_FILE_PROMPT,
SET_ACTIVE_BRANCH,
UPDATE_FILE_PROMPT,
)
from langchain_community.tools.gitlab.tool import GitLabAction
from langchain_community.utilities.gitlab import GitLabAPIWrapper
# only include a subset of tools by default to avoid a breaking change, where
# new tools are added to the toolkit and the user's code breaks because of
# the new tools
DEFAULT_INCLUDED_TOOLS = [
"get_issues",
"get_issue",
"comment_on_issue",
"create_pull_request",
"create_file",
"read_file",
"update_file",
"delete_file",
]
class GitLabToolkit(BaseToolkit):
"""GitLab Toolkit.
*Security Note*: This toolkit contains tools that can read and modify
the state of a service; e.g., by creating, deleting, updating, or
reading underlying data.
For example, this toolkit can be used to create issues, pull requests,
and comments on GitLab.
See https://python.langchain.com/docs/security for more information.
Parameters:
tools: List[BaseTool]. The tools in the toolkit. Default is an empty list.
"""
tools: List[BaseTool] = []
@classmethod
def from_gitlab_api_wrapper(
cls,
gitlab_api_wrapper: GitLabAPIWrapper,
*,
included_tools: Optional[List[str]] = None,
) -> "GitLabToolkit":
"""Create a GitLabToolkit from a GitLabAPIWrapper.
Args:
gitlab_api_wrapper: GitLabAPIWrapper. The GitLab API wrapper.
included_tools: Optional[List[str]]. Tool modes to include. Defaults to
DEFAULT_INCLUDED_TOOLS, a backwards-compatible subset.
Returns:
GitLabToolkit. The GitLab toolkit.
"""
tools_to_include = (
included_tools if included_tools is not None else DEFAULT_INCLUDED_TOOLS
)
operations: List[Dict] = [
{
"mode": "get_issues",
"name": "Get Issues",
"description": GET_ISSUES_PROMPT,
},
{
"mode": "get_issue",
"name": "Get Issue",
"description": GET_ISSUE_PROMPT,
},
{
"mode": "comment_on_issue",
"name": "Comment on Issue",
"description": COMMENT_ON_ISSUE_PROMPT,
},
{
"mode": "create_pull_request",
"name": "Create Pull Request",
"description": CREATE_PULL_REQUEST_PROMPT,
},
{
"mode": "create_file",
"name": "Create File",
"description": CREATE_FILE_PROMPT,
},
{
"mode": "read_file",
"name": "Read File",
"description": READ_FILE_PROMPT,
},
{
"mode": "update_file",
"name": "Update File",
"description": UPDATE_FILE_PROMPT,
},
{
"mode": "delete_file",
"name": "Delete File",
"description": DELETE_FILE_PROMPT,
},
{
"mode": "create_branch",
"name": "Create a new branch",
"description": CREATE_REPO_BRANCH,
},
{
"mode": "list_branches_in_repo",
"name": "Get the list of branches",
"description": LIST_REPO_BRANCES,
},
{
"mode": "set_active_branch",
"name": "Change the active branch",
"description": SET_ACTIVE_BRANCH,
},
{
"mode": "list_files_in_main_branch",
"name": "Overview of existing files in Main branch",
"description": GET_REPO_FILES_IN_MAIN,
},
{
"mode": "list_files_in_bot_branch",
"name": "Overview of files in current working branch",
"description": GET_REPO_FILES_IN_BOT_BRANCH,
},
{
"mode": "list_files_from_directory",
"name": "Overview of files in current working branch from a specific path", # noqa: E501
"description": GET_REPO_FILES_FROM_DIRECTORY,
},
]
operations_filtered = [
operation
for operation in operations
if operation["mode"] in tools_to_include
]
tools = [
GitLabAction(
name=action["name"],
description=action["description"],
mode=action["mode"],
api_wrapper=gitlab_api_wrapper,
)
for action in operations_filtered
]
return cls(tools=tools) # type: ignore[arg-type]
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
return self.tools
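# A minimal usage sketch (illustrative): GitLabAPIWrapper reads its settings,
# e.g. GITLAB_PERSONAL_ACCESS_TOKEN and GITLAB_REPOSITORY, from the environment.
gitlab = GitLabAPIWrapper()
# Opt in to tools beyond the backwards-compatible default subset:
toolkit = GitLabToolkit.from_gitlab_api_wrapper(
    gitlab,
    included_tools=["get_issues", "create_pull_request", "create_branch"],
)
print([tool.name for tool in toolkit.get_tools()])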

View File

@@ -1 +0,0 @@
"""Gmail toolkit."""

View File

@@ -1,132 +0,0 @@
from __future__ import annotations
from typing import TYPE_CHECKING, List
from langchain_core.tools import BaseTool
from langchain_core.tools.base import BaseToolkit
from pydantic import ConfigDict, Field
from langchain_community.tools.gmail.create_draft import GmailCreateDraft
from langchain_community.tools.gmail.get_message import GmailGetMessage
from langchain_community.tools.gmail.get_thread import GmailGetThread
from langchain_community.tools.gmail.search import GmailSearch
from langchain_community.tools.gmail.send_message import GmailSendMessage
from langchain_community.tools.gmail.utils import build_resource_service
if TYPE_CHECKING:
# This is for linting and IDE typehints
from googleapiclient.discovery import Resource
else:
try:
# We do this so pydantic can resolve the types when instantiating
from googleapiclient.discovery import Resource
except ImportError:
pass
SCOPES = ["https://mail.google.com/"]
class GmailToolkit(BaseToolkit):
"""Toolkit for interacting with Gmail.
*Security Note*: This toolkit contains tools that can read and modify
the state of a service; e.g., by reading, creating, updating, or
deleting data associated with this service.
For example, this toolkit can be used to send emails on behalf of the
associated account.
See https://python.langchain.com/docs/security for more information.
Setup:
You will need a Google credentials.json file to use this toolkit.
See instructions here: https://python.langchain.com/docs/integrations/tools/gmail/#setup
Key init args:
api_resource: Optional. The Google API resource. Default is None.
Instantiate:
.. code-block:: python
from langchain_google_community import GmailToolkit
toolkit = GmailToolkit()
Tools:
.. code-block:: python
toolkit.get_tools()
.. code-block:: none
[GmailCreateDraft(api_resource=<googleapiclient.discovery.Resource object at 0x1094509d0>),
GmailSendMessage(api_resource=<googleapiclient.discovery.Resource object at 0x1094509d0>),
GmailSearch(api_resource=<googleapiclient.discovery.Resource object at 0x1094509d0>),
GmailGetMessage(api_resource=<googleapiclient.discovery.Resource object at 0x1094509d0>),
GmailGetThread(api_resource=<googleapiclient.discovery.Resource object at 0x1094509d0>)]
Use within an agent:
.. code-block:: python
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import create_react_agent
tools = toolkit.get_tools()
llm = ChatOpenAI(model="gpt-4o-mini")
agent_executor = create_react_agent(llm, tools)
example_query = "Draft an email to fake@fake.com thanking them for coffee."
events = agent_executor.stream(
{"messages": [("user", example_query)]},
stream_mode="values",
)
for event in events:
event["messages"][-1].pretty_print()
.. code-block:: none
================================ Human Message =================================
Draft an email to fake@fake.com thanking them for coffee.
================================== Ai Message ==================================
Tool Calls:
create_gmail_draft (call_slGkYKZKA6h3Mf1CraUBzs6M)
Call ID: call_slGkYKZKA6h3Mf1CraUBzs6M
Args:
message: Dear Fake,
I wanted to take a moment to thank you for the coffee yesterday. It was a pleasure catching up with you. Let's do it again soon!
Best regards,
[Your Name]
to: ['fake@fake.com']
subject: Thank You for the Coffee
================================= Tool Message =================================
Name: create_gmail_draft
Draft created. Draft Id: r-7233782721440261513
================================== Ai Message ==================================
I have drafted an email to fake@fake.com thanking them for the coffee. You can review and send it from your email draft with the subject "Thank You for the Coffee".
Parameters:
api_resource: Optional. The Google API resource. Default is None.
""" # noqa: E501
api_resource: Resource = Field(default_factory=build_resource_service)
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
return [
GmailCreateDraft(api_resource=self.api_resource),
GmailSendMessage(api_resource=self.api_resource),
GmailSearch(api_resource=self.api_resource),
GmailGetMessage(api_resource=self.api_resource),
GmailGetThread(api_resource=self.api_resource),
]
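# A minimal setup sketch (illustrative): the file names below are placeholders
# for the credentials described in the setup link above.
from langchain_community.tools.gmail.utils import get_gmail_credentials

credentials = get_gmail_credentials(
    token_file="token.json",
    client_secrets_file="credentials.json",
    scopes=["https://mail.google.com/"],
)
toolkit = GmailToolkit(api_resource=build_resource_service(credentials=credentials))
tools = toolkit.get_tools()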

View File

@@ -1 +0,0 @@
"""Jira Toolkit."""

View File

@@ -1,83 +0,0 @@
from typing import Dict, List
from langchain_core.tools import BaseTool
from langchain_core.tools.base import BaseToolkit
from langchain_community.tools.jira.prompt import (
JIRA_CATCH_ALL_PROMPT,
JIRA_CONFLUENCE_PAGE_CREATE_PROMPT,
JIRA_GET_ALL_PROJECTS_PROMPT,
JIRA_ISSUE_CREATE_PROMPT,
JIRA_JQL_PROMPT,
)
from langchain_community.tools.jira.tool import JiraAction
from langchain_community.utilities.jira import JiraAPIWrapper
class JiraToolkit(BaseToolkit):
"""Jira Toolkit.
*Security Note*: This toolkit contains tools that can read and modify
the state of a service; e.g., by creating, deleting, updating, or
reading underlying data.
See https://python.langchain.com/docs/security for more information.
Parameters:
tools: List[BaseTool]. The tools in the toolkit. Default is an empty list.
"""
tools: List[BaseTool] = []
@classmethod
def from_jira_api_wrapper(cls, jira_api_wrapper: JiraAPIWrapper) -> "JiraToolkit":
"""Create a JiraToolkit from a JiraAPIWrapper.
Args:
jira_api_wrapper: JiraAPIWrapper. The Jira API wrapper.
Returns:
JiraToolkit. The Jira toolkit.
"""
operations: List[Dict] = [
{
"mode": "jql",
"name": "jql_query",
"description": JIRA_JQL_PROMPT,
},
{
"mode": "get_projects",
"name": "get_projects",
"description": JIRA_GET_ALL_PROJECTS_PROMPT,
},
{
"mode": "create_issue",
"name": "create_issue",
"description": JIRA_ISSUE_CREATE_PROMPT,
},
{
"mode": "other",
"name": "catch_all_jira_api",
"description": JIRA_CATCH_ALL_PROMPT,
},
{
"mode": "create_page",
"name": "create_confluence_page",
"description": JIRA_CONFLUENCE_PAGE_CREATE_PROMPT,
},
]
tools = [
JiraAction(
name=action["name"],
description=action["description"],
mode=action["mode"],
api_wrapper=jira_api_wrapper,
)
for action in operations
]
return cls(tools=tools) # type: ignore[arg-type]
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
return self.tools
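# A minimal usage sketch (illustrative): JiraAPIWrapper reads JIRA_API_TOKEN,
# JIRA_USERNAME, JIRA_INSTANCE_URL, and JIRA_CLOUD from the environment.
jira = JiraAPIWrapper()
toolkit = JiraToolkit.from_jira_api_wrapper(jira)
print([tool.name for tool in toolkit.get_tools()])
# -> ['jql_query', 'get_projects', 'create_issue', 'catch_all_jira_api',
#     'create_confluence_page']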

View File

@@ -1 +0,0 @@
"""Json agent."""

View File

@@ -1,76 +0,0 @@
"""Json agent."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from langchain_core.callbacks import BaseCallbackManager
from langchain_core.language_models import BaseLanguageModel
from langchain_community.agent_toolkits.json.prompt import JSON_PREFIX, JSON_SUFFIX
from langchain_community.agent_toolkits.json.toolkit import JsonToolkit
if TYPE_CHECKING:
from langchain.agents.agent import AgentExecutor
def create_json_agent(
llm: BaseLanguageModel,
toolkit: JsonToolkit,
callback_manager: Optional[BaseCallbackManager] = None,
prefix: str = JSON_PREFIX,
suffix: str = JSON_SUFFIX,
format_instructions: Optional[str] = None,
input_variables: Optional[List[str]] = None,
verbose: bool = False,
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> AgentExecutor:
"""Construct a json agent from an LLM and tools.
Args:
llm: The language model to use.
toolkit: The toolkit to use.
callback_manager: The callback manager to use. Default is None.
prefix: The prefix to use. Default is JSON_PREFIX.
suffix: The suffix to use. Default is JSON_SUFFIX.
format_instructions: The format instructions to use. Default is None.
input_variables: The input variables to use. Default is None.
verbose: Whether to print verbose output. Default is False.
agent_executor_kwargs: Optional additional arguments for the agent executor.
kwargs: Additional arguments for the agent.
Returns:
The agent executor.
"""
from langchain.agents.agent import AgentExecutor
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.chains.llm import LLMChain
tools = toolkit.get_tools()
prompt_params = (
{"format_instructions": format_instructions}
if format_instructions is not None
else {}
)
prompt = ZeroShotAgent.create_prompt(
tools,
prefix=prefix,
suffix=suffix,
input_variables=input_variables,
**prompt_params,
)
llm_chain = LLMChain(
llm=llm,
prompt=prompt,
callback_manager=callback_manager,
)
tool_names = [tool.name for tool in tools]
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
return AgentExecutor.from_agent_and_tools(
agent=agent,
tools=tools,
callback_manager=callback_manager,
verbose=verbose,
**(agent_executor_kwargs or {}),
)
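# A minimal usage sketch (illustrative): "openapi.yml" is a placeholder path,
# and the OpenAI model choice is arbitrary.
import yaml
from langchain_community.tools.json.tool import JsonSpec
from langchain_openai import OpenAI

with open("openapi.yml") as f:
    spec = JsonSpec(dict_=yaml.safe_load(f), max_value_length=4000)
agent_executor = create_json_agent(
    llm=OpenAI(temperature=0),
    toolkit=JsonToolkit(spec=spec),
    verbose=True,
)
agent_executor.run("What are the required parameters of the /completions endpoint?")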

View File

@@ -1,25 +0,0 @@
# flake8: noqa
JSON_PREFIX = """You are an agent designed to interact with JSON.
Your goal is to return a final answer by interacting with the JSON.
You have access to the following tools which help you learn more about the JSON you are interacting with.
Only use the below tools. Only use the information returned by the below tools to construct your final answer.
Do not make up any information that is not contained in the JSON.
Your input to the tools should be in the form of `data["key"][0]` where `data` is the JSON blob you are interacting with, and the syntax used is Python.
You should only use keys that you know for a fact exist. You must validate that a key exists by seeing it previously when calling `json_spec_list_keys`.
If you have not seen a key in one of those responses, you cannot use it.
You should only add one key at a time to the path. You cannot add multiple keys at once.
If you encounter a "KeyError", go back to the previous key, look at the available keys, and try again.
If the question does not seem to be related to the JSON, just return "I don't know" as the answer.
Always begin your interaction with the `json_spec_list_keys` tool with input "data" to see what keys exist in the JSON.
Note that sometimes the value at a given path is large. In this case, you will get an error "Value is a large dictionary, should explore its keys directly".
In this case, you should ALWAYS follow up by using the `json_spec_list_keys` tool to see what keys exist at that path.
Do not simply refer the user to the JSON or a section of the JSON, as this is not a valid answer. Keep digging until you find the answer and explicitly return it.
"""
JSON_SUFFIX = """Begin!"
Question: {input}
Thought: I should look at the keys that exist in data to see what I have access to
{agent_scratchpad}"""

View File

@@ -1,29 +0,0 @@
from __future__ import annotations
from typing import List
from langchain_core.tools import BaseTool
from langchain_core.tools.base import BaseToolkit
from langchain_community.tools.json.tool import (
JsonGetValueTool,
JsonListKeysTool,
JsonSpec,
)
class JsonToolkit(BaseToolkit):
"""Toolkit for interacting with a JSON spec.
Parameters:
spec: The JSON spec.
"""
spec: JsonSpec
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
return [
JsonListKeysTool(spec=self.spec),
JsonGetValueTool(spec=self.spec),
]
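# The two tools can also be used directly, without an agent (illustrative):
spec = JsonSpec(dict_={"info": {"title": "Demo API", "version": "1.0"}})
list_keys_tool, get_value_tool = JsonToolkit(spec=spec).get_tools()
print(list_keys_tool.run("data"))                   # ['info']
print(get_value_tool.run('data["info"]["title"]'))  # Demo API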

View File

@@ -1,771 +0,0 @@
# flake8: noqa
"""Tools provide access to various resources and services.
LangChain has a large ecosystem of integrations with various external resources
like local and remote file systems, APIs and databases.
These integrations allow developers to create versatile applications that combine the
power of LLMs with the ability to access, interact with and manipulate external
resources.
When developing an application, developers should inspect the capabilities and
permissions of the tools that underlie the given agent toolkit, and determine
whether permissions of the given toolkit are appropriate for the application.
See [Security](https://python.langchain.com/docs/security) for more information.
"""
import warnings
from typing import Any, Dict, List, Optional, Callable, Tuple
from mypy_extensions import Arg, KwArg
from langchain_community.tools.arxiv.tool import ArxivQueryRun
from langchain_community.tools.bing_search.tool import BingSearchRun
from langchain_community.tools.dataforseo_api_search import DataForSeoAPISearchResults
from langchain_community.tools.dataforseo_api_search import DataForSeoAPISearchRun
from langchain_community.tools.ddg_search.tool import DuckDuckGoSearchRun
from langchain_community.tools.eleven_labs.text2speech import ElevenLabsText2SpeechTool
from langchain_community.tools.file_management import ReadFileTool
from langchain_community.tools.golden_query.tool import GoldenQueryRun
from langchain_community.tools.google_cloud.texttospeech import (
GoogleCloudTextToSpeechTool,
)
from langchain_community.tools.google_finance.tool import GoogleFinanceQueryRun
from langchain_community.tools.google_jobs.tool import GoogleJobsQueryRun
from langchain_community.tools.google_lens.tool import GoogleLensQueryRun
from langchain_community.tools.google_scholar.tool import GoogleScholarQueryRun
from langchain_community.tools.google_search.tool import (
GoogleSearchResults,
GoogleSearchRun,
)
from langchain_community.tools.google_serper.tool import (
GoogleSerperResults,
GoogleSerperRun,
)
from langchain_community.tools.google_trends.tool import GoogleTrendsQueryRun
from langchain_community.tools.graphql.tool import BaseGraphQLTool
from langchain_community.tools.human.tool import HumanInputRun
from langchain_community.tools.memorize.tool import Memorize
from langchain_community.tools.merriam_webster.tool import MerriamWebsterQueryRun
from langchain_community.tools.metaphor_search.tool import MetaphorSearchResults
from langchain_community.tools.openweathermap.tool import OpenWeatherMapQueryRun
from langchain_community.tools.pubmed.tool import PubmedQueryRun
from langchain_community.tools.reddit_search.tool import RedditSearchRun
from langchain_community.tools.requests.tool import (
RequestsDeleteTool,
RequestsGetTool,
RequestsPatchTool,
RequestsPostTool,
RequestsPutTool,
)
from langchain_community.tools.scenexplain.tool import SceneXplainTool
from langchain_community.tools.searchapi.tool import SearchAPIResults, SearchAPIRun
from langchain_community.tools.searx_search.tool import (
SearxSearchResults,
SearxSearchRun,
)
from langchain_community.tools.shell.tool import ShellTool
from langchain_community.tools.sleep.tool import SleepTool
from langchain_community.tools.stackexchange.tool import StackExchangeTool
from langchain_community.tools.wikipedia.tool import WikipediaQueryRun
from langchain_community.tools.wolfram_alpha.tool import WolframAlphaQueryRun
from langchain_community.utilities.arxiv import ArxivAPIWrapper
from langchain_community.utilities.awslambda import LambdaWrapper
from langchain_community.utilities.bing_search import BingSearchAPIWrapper
from langchain_community.utilities.dalle_image_generator import DallEAPIWrapper
from langchain_community.utilities.dataforseo_api_search import DataForSeoAPIWrapper
from langchain_community.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper
from langchain_community.utilities.golden_query import GoldenQueryAPIWrapper
from langchain_community.utilities.google_books import GoogleBooksAPIWrapper
from langchain_community.utilities.google_finance import GoogleFinanceAPIWrapper
from langchain_community.utilities.google_jobs import GoogleJobsAPIWrapper
from langchain_community.utilities.google_lens import GoogleLensAPIWrapper
from langchain_community.utilities.google_scholar import GoogleScholarAPIWrapper
from langchain_community.utilities.google_search import GoogleSearchAPIWrapper
from langchain_community.utilities.google_serper import GoogleSerperAPIWrapper
from langchain_community.utilities.google_trends import GoogleTrendsAPIWrapper
from langchain_community.utilities.graphql import GraphQLAPIWrapper
from langchain_community.utilities.merriam_webster import MerriamWebsterAPIWrapper
from langchain_community.utilities.metaphor_search import MetaphorSearchAPIWrapper
from langchain_community.utilities.openweathermap import OpenWeatherMapAPIWrapper
from langchain_community.utilities.pubmed import PubMedAPIWrapper
from langchain_community.utilities.reddit_search import RedditSearchAPIWrapper
from langchain_community.utilities.requests import TextRequestsWrapper
from langchain_community.utilities.searchapi import SearchApiAPIWrapper
from langchain_community.utilities.searx_search import SearxSearchWrapper
from langchain_community.utilities.serpapi import SerpAPIWrapper
from langchain_community.utilities.stackexchange import StackExchangeAPIWrapper
from langchain_community.utilities.twilio import TwilioAPIWrapper
from langchain_community.utilities.wikipedia import WikipediaAPIWrapper
from langchain_community.utilities.wolfram_alpha import WolframAlphaAPIWrapper
from langchain_core.callbacks import BaseCallbackManager
from langchain_core.callbacks import Callbacks
from langchain_core.language_models import BaseLanguageModel
from langchain_core.tools import BaseTool, Tool
def _get_tools_requests_get() -> BaseTool:
# Dangerous requests are allowed here, because there's another flag that the user
# has to provide in order to actually opt in.
# This is a private function and should not be used directly.
return RequestsGetTool(
requests_wrapper=TextRequestsWrapper(), allow_dangerous_requests=True
)
def _get_tools_requests_post() -> BaseTool:
# Dangerous requests are allowed here, because there's another flag that the user
# has to provide in order to actually opt in.
# This is a private function and should not be used directly.
return RequestsPostTool(
requests_wrapper=TextRequestsWrapper(), allow_dangerous_requests=True
)
def _get_tools_requests_patch() -> BaseTool:
# Dangerous requests are allowed here, because there's another flag that the user
# has to provide in order to actually opt in.
# This is a private function and should not be used directly.
return RequestsPatchTool(
requests_wrapper=TextRequestsWrapper(), allow_dangerous_requests=True
)
def _get_tools_requests_put() -> BaseTool:
# Dangerous requests are allowed here, because there's another flag that the user
# has to provide in order to actually opt in.
# This is a private function and should not be used directly.
return RequestsPutTool(
requests_wrapper=TextRequestsWrapper(), allow_dangerous_requests=True
)
def _get_tools_requests_delete() -> BaseTool:
# Dangerous requests are allowed here, because there's another flag that the user
# has to provide in order to actually opt in.
# This is a private function and should not be used directly.
return RequestsDeleteTool(
requests_wrapper=TextRequestsWrapper(), allow_dangerous_requests=True
)
def _get_terminal() -> BaseTool:
return ShellTool()
def _get_sleep() -> BaseTool:
return SleepTool()
_BASE_TOOLS: Dict[str, Callable[[], BaseTool]] = {
"sleep": _get_sleep,
}
DANGEROUS_TOOLS = {
# Tools that contain some level of risk.
# Please use with caution and read the documentation of these tools
# to understand the risks and how to mitigate them.
# Refer to https://python.langchain.com/docs/security
# for more information.
"requests": _get_tools_requests_get, # preserved for backwards compatibility
"requests_get": _get_tools_requests_get,
"requests_post": _get_tools_requests_post,
"requests_patch": _get_tools_requests_patch,
"requests_put": _get_tools_requests_put,
"requests_delete": _get_tools_requests_delete,
"terminal": _get_terminal,
}
def _get_llm_math(llm: BaseLanguageModel) -> BaseTool:
try:
from langchain.chains.llm_math.base import LLMMathChain
except ImportError:
raise ImportError(
"LLM Math tools require the library `langchain` to be installed."
" Please install it with `pip install langchain`."
)
return Tool(
name="Calculator",
description="Useful for when you need to answer questions about math.",
func=LLMMathChain.from_llm(llm=llm).run,
coroutine=LLMMathChain.from_llm(llm=llm).arun,
)
def _get_open_meteo_api(llm: BaseLanguageModel) -> BaseTool:
try:
from langchain.chains.api.base import APIChain
from langchain.chains.api import (
open_meteo_docs,
)
except ImportError:
raise ImportError(
"API tools require the library `langchain` to be installed."
" Please install it with `pip install langchain`."
)
chain = APIChain.from_llm_and_api_docs(
llm,
open_meteo_docs.OPEN_METEO_DOCS,
limit_to_domains=["https://api.open-meteo.com/"],
)
return Tool(
name="Open-Meteo-API",
description="Useful for when you want to get weather information from the OpenMeteo API. The input should be a question in natural language that this API can answer.",
func=chain.run,
)
_LLM_TOOLS: Dict[str, Callable[[BaseLanguageModel], BaseTool]] = {
"llm-math": _get_llm_math,
"open-meteo-api": _get_open_meteo_api,
}
def _get_news_api(llm: BaseLanguageModel, **kwargs: Any) -> BaseTool:
news_api_key = kwargs["news_api_key"]
try:
from langchain.chains.api.base import APIChain
from langchain.chains.api import (
news_docs,
)
except ImportError:
raise ImportError(
"API tools require the library `langchain` to be installed."
" Please install it with `pip install langchain`."
)
chain = APIChain.from_llm_and_api_docs(
llm,
news_docs.NEWS_DOCS,
headers={"X-Api-Key": news_api_key},
limit_to_domains=["https://newsapi.org/"],
)
return Tool(
name="News-API",
description="Use this when you want to get information about the top headlines of current news stories. The input should be a question in natural language that this API can answer.",
func=chain.run,
)
def _get_tmdb_api(llm: BaseLanguageModel, **kwargs: Any) -> BaseTool:
tmdb_bearer_token = kwargs["tmdb_bearer_token"]
try:
from langchain.chains.api.base import APIChain
from langchain.chains.api import (
tmdb_docs,
)
except ImportError:
raise ImportError(
"API tools require the library `langchain` to be installed."
" Please install it with `pip install langchain`."
)
chain = APIChain.from_llm_and_api_docs(
llm,
tmdb_docs.TMDB_DOCS,
headers={"Authorization": f"Bearer {tmdb_bearer_token}"},
limit_to_domains=["https://api.themoviedb.org/"],
)
return Tool(
name="TMDB-API",
description="Useful for when you want to get information from The Movie Database. The input should be a question in natural language that this API can answer.",
func=chain.run,
)
def _get_podcast_api(llm: BaseLanguageModel, **kwargs: Any) -> BaseTool:
listen_api_key = kwargs["listen_api_key"]
try:
from langchain.chains.api.base import APIChain
from langchain.chains.api import (
podcast_docs,
)
except ImportError:
raise ImportError(
"API tools require the library `langchain` to be installed."
" Please install it with `pip install langchain`."
)
chain = APIChain.from_llm_and_api_docs(
llm,
podcast_docs.PODCAST_DOCS,
headers={"X-ListenAPI-Key": listen_api_key},
limit_to_domains=["https://listen-api.listennotes.com/"],
)
return Tool(
name="Podcast-API",
description="Use the Listen Notes Podcast API to search all podcasts or episodes. The input should be a question in natural language that this API can answer.",
func=chain.run,
)
def _get_lambda_api(**kwargs: Any) -> BaseTool:
return Tool(
name=kwargs["awslambda_tool_name"],
description=kwargs["awslambda_tool_description"],
func=LambdaWrapper(**kwargs).run,
)
def _get_wolfram_alpha(**kwargs: Any) -> BaseTool:
return WolframAlphaQueryRun(api_wrapper=WolframAlphaAPIWrapper(**kwargs))
def _get_google_search(**kwargs: Any) -> BaseTool:
return GoogleSearchRun(api_wrapper=GoogleSearchAPIWrapper(**kwargs))
def _get_merriam_webster(**kwargs: Any) -> BaseTool:
return MerriamWebsterQueryRun(api_wrapper=MerriamWebsterAPIWrapper(**kwargs))
def _get_wikipedia(**kwargs: Any) -> BaseTool:
return WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper(**kwargs))
def _get_arxiv(**kwargs: Any) -> BaseTool:
return ArxivQueryRun(api_wrapper=ArxivAPIWrapper(**kwargs))
def _get_golden_query(**kwargs: Any) -> BaseTool:
return GoldenQueryRun(api_wrapper=GoldenQueryAPIWrapper(**kwargs))
def _get_pubmed(**kwargs: Any) -> BaseTool:
return PubmedQueryRun(api_wrapper=PubMedAPIWrapper(**kwargs))
def _get_google_books(**kwargs: Any) -> BaseTool:
from langchain_community.tools.google_books import GoogleBooksQueryRun
return GoogleBooksQueryRun(api_wrapper=GoogleBooksAPIWrapper(**kwargs))
def _get_google_jobs(**kwargs: Any) -> BaseTool:
return GoogleJobsQueryRun(api_wrapper=GoogleJobsAPIWrapper(**kwargs))
def _get_google_lens(**kwargs: Any) -> BaseTool:
return GoogleLensQueryRun(api_wrapper=GoogleLensAPIWrapper(**kwargs))
def _get_google_serper(**kwargs: Any) -> BaseTool:
return GoogleSerperRun(api_wrapper=GoogleSerperAPIWrapper(**kwargs))
def _get_google_scholar(**kwargs: Any) -> BaseTool:
return GoogleScholarQueryRun(api_wrapper=GoogleScholarAPIWrapper(**kwargs))
def _get_google_finance(**kwargs: Any) -> BaseTool:
return GoogleFinanceQueryRun(api_wrapper=GoogleFinanceAPIWrapper(**kwargs))
def _get_google_trends(**kwargs: Any) -> BaseTool:
return GoogleTrendsQueryRun(api_wrapper=GoogleTrendsAPIWrapper(**kwargs))
def _get_google_serper_results_json(**kwargs: Any) -> BaseTool:
return GoogleSerperResults(api_wrapper=GoogleSerperAPIWrapper(**kwargs))
def _get_google_search_results_json(**kwargs: Any) -> BaseTool:
return GoogleSearchResults(api_wrapper=GoogleSearchAPIWrapper(**kwargs))
def _get_searchapi(**kwargs: Any) -> BaseTool:
return SearchAPIRun(api_wrapper=SearchApiAPIWrapper(**kwargs))
def _get_searchapi_results_json(**kwargs: Any) -> BaseTool:
return SearchAPIResults(api_wrapper=SearchApiAPIWrapper(**kwargs))
def _get_serpapi(**kwargs: Any) -> BaseTool:
return Tool(
name="Search",
description="A search engine. Useful for when you need to answer questions about current events. Input should be a search query.",
func=SerpAPIWrapper(**kwargs).run,
coroutine=SerpAPIWrapper(**kwargs).arun,
)
def _get_stackexchange(**kwargs: Any) -> BaseTool:
return StackExchangeTool(api_wrapper=StackExchangeAPIWrapper(**kwargs))
def _get_dalle_image_generator(**kwargs: Any) -> Tool:
return Tool(
"Dall-E-Image-Generator",
DallEAPIWrapper(**kwargs).run,
"A wrapper around OpenAI DALL-E API. Useful for when you need to generate images from a text description. Input should be an image description.",
)
def _get_twilio(**kwargs: Any) -> BaseTool:
return Tool(
name="Text-Message",
description="Useful for when you need to send a text message to a provided phone number.",
func=TwilioAPIWrapper(**kwargs).run,
)
def _get_searx_search(**kwargs: Any) -> BaseTool:
return SearxSearchRun(wrapper=SearxSearchWrapper(**kwargs))
def _get_searx_search_results_json(**kwargs: Any) -> BaseTool:
wrapper_kwargs = {k: v for k, v in kwargs.items() if k != "num_results"}
return SearxSearchResults(wrapper=SearxSearchWrapper(**wrapper_kwargs), **kwargs)
def _get_bing_search(**kwargs: Any) -> BaseTool:
return BingSearchRun(api_wrapper=BingSearchAPIWrapper(**kwargs))
def _get_metaphor_search(**kwargs: Any) -> BaseTool:
return MetaphorSearchResults(api_wrapper=MetaphorSearchAPIWrapper(**kwargs))
def _get_ddg_search(**kwargs: Any) -> BaseTool:
return DuckDuckGoSearchRun(api_wrapper=DuckDuckGoSearchAPIWrapper(**kwargs))
def _get_human_tool(**kwargs: Any) -> BaseTool:
return HumanInputRun(**kwargs)
def _get_scenexplain(**kwargs: Any) -> BaseTool:
return SceneXplainTool(**kwargs)
def _get_graphql_tool(**kwargs: Any) -> BaseTool:
return BaseGraphQLTool(graphql_wrapper=GraphQLAPIWrapper(**kwargs))
def _get_openweathermap(**kwargs: Any) -> BaseTool:
return OpenWeatherMapQueryRun(api_wrapper=OpenWeatherMapAPIWrapper(**kwargs))
def _get_dataforseo_api_search(**kwargs: Any) -> BaseTool:
return DataForSeoAPISearchRun(api_wrapper=DataForSeoAPIWrapper(**kwargs))
def _get_dataforseo_api_search_json(**kwargs: Any) -> BaseTool:
return DataForSeoAPISearchResults(api_wrapper=DataForSeoAPIWrapper(**kwargs))
def _get_eleven_labs_text2speech(**kwargs: Any) -> BaseTool:
return ElevenLabsText2SpeechTool(**kwargs)
def _get_memorize(llm: BaseLanguageModel, **kwargs: Any) -> BaseTool:
return Memorize(llm=llm) # type: ignore[arg-type]
def _get_google_cloud_texttospeech(**kwargs: Any) -> BaseTool:
return GoogleCloudTextToSpeechTool(**kwargs)
def _get_file_management_tool(**kwargs: Any) -> BaseTool:
return ReadFileTool(**kwargs)
def _get_reddit_search(**kwargs: Any) -> BaseTool:
return RedditSearchRun(api_wrapper=RedditSearchAPIWrapper(**kwargs))
_EXTRA_LLM_TOOLS: Dict[
str,
Tuple[Callable[[Arg(BaseLanguageModel, "llm"), KwArg(Any)], BaseTool], List[str]],
] = {
"news-api": (_get_news_api, ["news_api_key"]),
"tmdb-api": (_get_tmdb_api, ["tmdb_bearer_token"]),
"podcast-api": (_get_podcast_api, ["listen_api_key"]),
"memorize": (_get_memorize, []),
}
_EXTRA_OPTIONAL_TOOLS: Dict[str, Tuple[Callable[[KwArg(Any)], BaseTool], List[str]]] = {
"wolfram-alpha": (_get_wolfram_alpha, ["wolfram_alpha_appid"]),
"google-search": (_get_google_search, ["google_api_key", "google_cse_id"]),
"google-search-results-json": (
_get_google_search_results_json,
["google_api_key", "google_cse_id", "num_results"],
),
"searx-search-results-json": (
_get_searx_search_results_json,
["searx_host", "engines", "num_results", "aiosession"],
),
"bing-search": (_get_bing_search, ["bing_subscription_key", "bing_search_url"]),
"metaphor-search": (_get_metaphor_search, ["metaphor_api_key"]),
"ddg-search": (_get_ddg_search, []),
"google-books": (_get_google_books, ["google_books_api_key"]),
"google-lens": (_get_google_lens, ["serp_api_key"]),
"google-serper": (_get_google_serper, ["serper_api_key", "aiosession"]),
"google-scholar": (
_get_google_scholar,
["top_k_results", "hl", "lr", "serp_api_key"],
),
"google-finance": (
_get_google_finance,
["serp_api_key"],
),
"google-trends": (
_get_google_trends,
["serp_api_key"],
),
"google-jobs": (
_get_google_jobs,
["serp_api_key"],
),
"google-serper-results-json": (
_get_google_serper_results_json,
["serper_api_key", "aiosession"],
),
"searchapi": (_get_searchapi, ["searchapi_api_key", "aiosession"]),
"searchapi-results-json": (
_get_searchapi_results_json,
["searchapi_api_key", "aiosession"],
),
"serpapi": (_get_serpapi, ["serpapi_api_key", "aiosession"]),
"dalle-image-generator": (_get_dalle_image_generator, ["openai_api_key"]),
"twilio": (_get_twilio, ["account_sid", "auth_token", "from_number"]),
"searx-search": (_get_searx_search, ["searx_host", "engines", "aiosession"]),
"merriam-webster": (_get_merriam_webster, ["merriam_webster_api_key"]),
"wikipedia": (_get_wikipedia, ["top_k_results", "lang"]),
"arxiv": (
_get_arxiv,
["top_k_results", "load_max_docs", "load_all_available_meta"],
),
"golden-query": (_get_golden_query, ["golden_api_key"]),
"pubmed": (_get_pubmed, ["top_k_results"]),
"human": (_get_human_tool, ["prompt_func", "input_func"]),
"awslambda": (
_get_lambda_api,
["awslambda_tool_name", "awslambda_tool_description", "function_name"],
),
"stackexchange": (_get_stackexchange, []),
"sceneXplain": (_get_scenexplain, []),
"graphql": (
_get_graphql_tool,
["graphql_endpoint", "custom_headers", "fetch_schema_from_transport"],
),
"openweathermap-api": (_get_openweathermap, ["openweathermap_api_key"]),
"dataforseo-api-search": (
_get_dataforseo_api_search,
["api_login", "api_password", "aiosession"],
),
"dataforseo-api-search-json": (
_get_dataforseo_api_search_json,
["api_login", "api_password", "aiosession"],
),
"eleven_labs_text2speech": (_get_eleven_labs_text2speech, ["elevenlabs_api_key"]),
"google_cloud_texttospeech": (_get_google_cloud_texttospeech, []),
"read_file": (_get_file_management_tool, []),
"reddit_search": (
_get_reddit_search,
["reddit_client_id", "reddit_client_secret", "reddit_user_agent"],
),
}
def _handle_callbacks(
callback_manager: Optional[BaseCallbackManager], callbacks: Callbacks
) -> Callbacks:
if callback_manager is not None:
warnings.warn(
"callback_manager is deprecated. Please use callbacks instead.",
DeprecationWarning,
)
if callbacks is not None:
raise ValueError(
"Cannot specify both callback_manager and callbacks arguments."
)
return callback_manager
return callbacks
def load_huggingface_tool(
task_or_repo_id: str,
model_repo_id: Optional[str] = None,
token: Optional[str] = None,
remote: bool = False,
**kwargs: Any,
) -> BaseTool:
"""Loads a tool from the HuggingFace Hub.
Args:
task_or_repo_id: Task or model repo id.
model_repo_id: Optional model repo id. Defaults to None.
token: Optional token. Defaults to None.
remote: Whether to run the tool remotely. Defaults to False.
kwargs: Additional keyword arguments.
Returns:
A tool.
Raises:
ImportError: If the required libraries are not installed.
NotImplementedError: If multimodal outputs or inputs are not supported.
"""
try:
from transformers import load_tool
except ImportError:
raise ImportError(
"HuggingFace tools require the libraries `transformers>=4.29.0`"
" and `huggingface_hub>=0.14.1` to be installed."
" Please install it with"
" `pip install --upgrade transformers huggingface_hub`."
)
hf_tool = load_tool(
task_or_repo_id,
model_repo_id=model_repo_id,
token=token,
remote=remote,
**kwargs,
)
outputs = hf_tool.outputs
if set(outputs) != {"text"}:
raise NotImplementedError("Multimodal outputs not supported yet.")
inputs = hf_tool.inputs
if set(inputs) != {"text"}:
raise NotImplementedError("Multimodal inputs not supported yet.")
return Tool.from_function(
hf_tool.__call__, name=hf_tool.name, description=hf_tool.description
)
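# A minimal usage sketch (illustrative): the repo id is the example used in the
# LangChain docs; `transformers` and `huggingface_hub` must be installed.
tool = load_huggingface_tool("lysandre/hf-model-downloads")
print(tool.run("text-classification"))  # name of the most-downloaded model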
def raise_dangerous_tools_exception(name: str) -> None:
raise ValueError(
f"{name} is a dangerous tool. You cannot use it without opting in "
"by setting allow_dangerous_tools to True. "
"Most tools have some inherit risk to them merely because they are "
'allowed to interact with the "real world".'
"Please refer to LangChain security guidelines "
"to https://python.langchain.com/docs/security."
"Some tools have been designated as dangerous because they pose "
"risk that is not intuitively obvious. For example, a tool that "
"allows an agent to make requests to the web, can also be used "
"to make requests to a server that is only accessible from the "
"server hosting the code."
"Again, all tools carry some risk, and it's your responsibility to "
"understand which tools you're using and the risks associated with "
"them."
)
def load_tools(
tool_names: List[str],
llm: Optional[BaseLanguageModel] = None,
callbacks: Callbacks = None,
allow_dangerous_tools: bool = False,
**kwargs: Any,
) -> List[BaseTool]:
"""Load tools based on their name.
Tools allow agents to interact with various resources and services like
APIs, databases, file systems, etc.
Please scope the permissions of each tool to the minimum required for the
application.
For example, if an application only needs to read from a database,
the database tool should not be given write permissions. Moreover,
consider scoping the permissions to only allow accessing specific
tables and imposing user-level quotas to limit resource usage.
Please read the APIs of the individual tools to determine which configuration
they support.
See [Security](https://python.langchain.com/docs/security) for more information.
Args:
tool_names: name of tools to load.
llm: An optional language model may be needed to initialize certain tools.
Defaults to None.
callbacks: Optional callback manager or list of callback handlers.
If not provided, default global callback manager will be used.
allow_dangerous_tools: Optional flag to allow dangerous tools.
Tools that contain some level of risk.
Please use with caution and read the documentation of these tools
to understand the risks and how to mitigate them.
Refer to https://python.langchain.com/docs/security
for more information.
Please note that this list may not be fully exhaustive.
It is your responsibility to understand which tools
you're using and the risks associated with them.
Defaults to False.
kwargs: Additional keyword arguments.
Returns:
List of tools.
Raises:
ValueError: If the tool name is unknown.
ValueError: If the tool requires an LLM to be provided.
ValueError: If the tool requires some parameters that were not provided.
ValueError: If the tool is a dangerous tool and allow_dangerous_tools is False.
"""
tools = []
callbacks = _handle_callbacks(
callback_manager=kwargs.get("callback_manager"), callbacks=callbacks
)
for name in tool_names:
if name in DANGEROUS_TOOLS and not allow_dangerous_tools:
raise_dangerous_tools_exception(name)
if name in {"requests"}:
warnings.warn(
"tool name `requests` is deprecated - "
"please use `requests_all` or specify the requests method"
)
if name == "requests_all":
# expand requests into various methods
if not allow_dangerous_tools:
raise_dangerous_tools_exception(name)
requests_method_tools = [
_tool for _tool in DANGEROUS_TOOLS if _tool.startswith("requests_")
]
tool_names.extend(requests_method_tools)
elif name in _BASE_TOOLS:
tools.append(_BASE_TOOLS[name]())
elif name in DANGEROUS_TOOLS:
tools.append(DANGEROUS_TOOLS[name]())
elif name in _LLM_TOOLS:
if llm is None:
raise ValueError(f"Tool {name} requires an LLM to be provided")
tool = _LLM_TOOLS[name](llm)
tools.append(tool)
elif name in _EXTRA_LLM_TOOLS:
if llm is None:
raise ValueError(f"Tool {name} requires an LLM to be provided")
_get_llm_tool_func, extra_keys = _EXTRA_LLM_TOOLS[name]
missing_keys = set(extra_keys).difference(kwargs)
if missing_keys:
raise ValueError(
f"Tool {name} requires some parameters that were not "
f"provided: {missing_keys}"
)
sub_kwargs = {k: kwargs[k] for k in extra_keys}
tool = _get_llm_tool_func(llm=llm, **sub_kwargs)
tools.append(tool)
elif name in _EXTRA_OPTIONAL_TOOLS:
_get_tool_func, extra_keys = _EXTRA_OPTIONAL_TOOLS[name]
sub_kwargs = {k: kwargs[k] for k in extra_keys if k in kwargs}
tool = _get_tool_func(**sub_kwargs)
tools.append(tool)
else:
raise ValueError(f"Got unknown tool {name}")
if callbacks is not None:
for tool in tools:
tool.callbacks = callbacks
return tools
def get_all_tool_names() -> List[str]:
"""Get a list of all possible tool names."""
return (
list(_BASE_TOOLS)
+ list(_EXTRA_OPTIONAL_TOOLS)
+ list(_EXTRA_LLM_TOOLS)
+ list(_LLM_TOOLS)
+ list(DANGEROUS_TOOLS)
)
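# A minimal usage sketch (illustrative): each tool also needs its own optional
# dependency installed (e.g. duckduckgo-search, numexpr).
from langchain_openai import OpenAI

assert "llm-math" in get_all_tool_names()
tools = load_tools(
    ["ddg-search", "llm-math", "requests_get"],
    llm=OpenAI(temperature=0),   # required by llm-math
    allow_dangerous_tools=True,  # required by requests_get
)
print([tool.name for tool in tools])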

View File

@@ -1 +0,0 @@
"""MultiOn Toolkit."""

View File

@@ -1,35 +0,0 @@
"""MultiOn agent."""
from __future__ import annotations
from typing import List
from langchain_core.tools import BaseTool
from langchain_core.tools.base import BaseToolkit
from pydantic import ConfigDict
from langchain_community.tools.multion.close_session import MultionCloseSession
from langchain_community.tools.multion.create_session import MultionCreateSession
from langchain_community.tools.multion.update_session import MultionUpdateSession
class MultionToolkit(BaseToolkit):
"""Toolkit for interacting with the Browser Agent.
**Security Note**: This toolkit contains tools that interact with the
user's browser via the MultiOn API, which grants an agent
access to the user's browser.
Please review the documentation for the MultiOn API to understand
the security implications of using this toolkit.
See https://python.langchain.com/docs/security for more information.
"""
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
return [MultionCreateSession(), MultionUpdateSession(), MultionCloseSession()]
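# A minimal usage sketch (illustrative): requires the `multion` package and an
# authenticated MultiOn session (see the MultiOn docs).
toolkit = MultionToolkit()
tools = toolkit.get_tools()  # create, update, and close session tools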

View File

@@ -1 +0,0 @@
"""NASA Toolkit"""

View File

@@ -1,62 +0,0 @@
from typing import Dict, List
from langchain_core.tools import BaseTool
from langchain_core.tools.base import BaseToolkit
from langchain_community.tools.nasa.prompt import (
NASA_CAPTIONS_PROMPT,
NASA_MANIFEST_PROMPT,
NASA_METADATA_PROMPT,
NASA_SEARCH_PROMPT,
)
from langchain_community.tools.nasa.tool import NasaAction
from langchain_community.utilities.nasa import NasaAPIWrapper
class NasaToolkit(BaseToolkit):
"""Nasa Toolkit.
Parameters:
tools: List[BaseTool]. The tools in the toolkit. Default is an empty list.
"""
tools: List[BaseTool] = []
@classmethod
def from_nasa_api_wrapper(cls, nasa_api_wrapper: NasaAPIWrapper) -> "NasaToolkit":
"""Create a NasaToolkit from a NasaAPIWrapper."""
operations: List[Dict] = [
{
"mode": "search_media",
"name": "Search NASA Image and Video Library media",
"description": NASA_SEARCH_PROMPT,
},
{
"mode": "get_media_metadata_manifest",
"name": "Get NASA Image and Video Library media metadata manifest",
"description": NASA_MANIFEST_PROMPT,
},
{
"mode": "get_media_metadata_location",
"name": "Get NASA Image and Video Library media metadata location",
"description": NASA_METADATA_PROMPT,
},
{
"mode": "get_video_captions_location",
"name": "Get NASA Image and Video Library video captions location",
"description": NASA_CAPTIONS_PROMPT,
},
]
tools = [
NasaAction(
name=action["name"],
description=action["description"],
mode=action["mode"],
api_wrapper=nasa_api_wrapper,
)
for action in operations
]
return cls(tools=tools) # type: ignore[arg-type]
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
return self.tools
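# A minimal usage sketch (illustrative): the NASA Image and Video Library
# endpoints used here do not require an API key.
nasa = NasaAPIWrapper()
toolkit = NasaToolkit.from_nasa_api_wrapper(nasa)
tools = toolkit.get_tools()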

View File

@@ -1,79 +0,0 @@
"""Tool for interacting with a single API with natural language definition."""
from __future__ import annotations
from typing import Any, Optional
from langchain_core.language_models import BaseLanguageModel
from langchain_core.tools import Tool
from langchain_community.chains.openapi.chain import OpenAPIEndpointChain
from langchain_community.tools.openapi.utils.api_models import APIOperation
from langchain_community.tools.openapi.utils.openapi_utils import OpenAPISpec
from langchain_community.utilities.requests import Requests
class NLATool(Tool):
"""Natural Language API Tool."""
@classmethod
def from_open_api_endpoint_chain(
cls, chain: OpenAPIEndpointChain, api_title: str
) -> "NLATool":
"""Convert an endpoint chain to an API endpoint tool.
Args:
chain: The endpoint chain.
api_title: The title of the API.
Returns:
The API endpoint tool.
"""
expanded_name = (
f"{api_title.replace(' ', '_')}.{chain.api_operation.operation_id}"
)
description = (
f"I'm an AI from {api_title}. Instruct what you want,"
" and I'll assist via an API with description:"
f" {chain.api_operation.description}"
)
return cls(name=expanded_name, func=chain.run, description=description)
@classmethod
def from_llm_and_method(
cls,
llm: BaseLanguageModel,
path: str,
method: str,
spec: OpenAPISpec,
requests: Optional[Requests] = None,
verbose: bool = False,
return_intermediate_steps: bool = False,
**kwargs: Any,
) -> "NLATool":
"""Instantiate the tool from the specified path and method.
Args:
llm: The language model to use.
path: The path of the API.
method: The method of the API.
spec: The OpenAPI spec.
requests: Optional requests object. Default is None.
verbose: Whether to print verbose output. Default is False.
return_intermediate_steps: Whether to return intermediate steps.
Default is False.
kwargs: Additional arguments.
Returns:
The tool.
"""
api_operation = APIOperation.from_openapi_spec(spec, path, method)
chain = OpenAPIEndpointChain.from_api_operation(
api_operation,
llm,
requests=requests,
verbose=verbose,
return_intermediate_steps=return_intermediate_steps,
**kwargs,
)
return cls.from_open_api_endpoint_chain(chain, spec.info.title)

View File

@@ -1,150 +0,0 @@
from __future__ import annotations
from typing import Any, List, Optional, Sequence
from langchain_core.language_models import BaseLanguageModel
from langchain_core.tools import BaseTool
from langchain_core.tools.base import BaseToolkit
from pydantic import Field
from langchain_community.agent_toolkits.nla.tool import NLATool
from langchain_community.tools.openapi.utils.openapi_utils import OpenAPISpec
from langchain_community.tools.plugin import AIPlugin
from langchain_community.utilities.requests import Requests
class NLAToolkit(BaseToolkit):
"""Natural Language API Toolkit.
*Security Note*: This toolkit creates tools that enable making calls
to an Open API compliant API.
The tools created by this toolkit may be able to make GET, POST,
PATCH, PUT, DELETE requests to any of the exposed endpoints on
the API.
Control access to who can use this toolkit.
See https://python.langchain.com/docs/security for more information.
"""
nla_tools: Sequence[NLATool] = Field(...)
"""List of API Endpoint Tools."""
def get_tools(self) -> List[BaseTool]:
"""Get the tools for all the API operations."""
return list(self.nla_tools)
@staticmethod
def _get_http_operation_tools(
llm: BaseLanguageModel,
spec: OpenAPISpec,
requests: Optional[Requests] = None,
verbose: bool = False,
**kwargs: Any,
) -> List[NLATool]:
"""Get the tools for all the API operations."""
if not spec.paths:
return []
http_operation_tools = []
for path in spec.paths:
for method in spec.get_methods_for_path(path):
endpoint_tool = NLATool.from_llm_and_method(
llm=llm,
path=path,
method=method,
spec=spec,
requests=requests,
verbose=verbose,
**kwargs,
)
http_operation_tools.append(endpoint_tool)
return http_operation_tools
@classmethod
def from_llm_and_spec(
cls,
llm: BaseLanguageModel,
spec: OpenAPISpec,
requests: Optional[Requests] = None,
verbose: bool = False,
**kwargs: Any,
) -> NLAToolkit:
"""Instantiate the toolkit by creating tools for each operation.
Args:
llm: The language model to use.
spec: The OpenAPI spec.
requests: Optional requests object. Default is None.
verbose: Whether to print verbose output. Default is False.
kwargs: Additional arguments.
Returns:
The toolkit.
"""
http_operation_tools = cls._get_http_operation_tools(
llm=llm, spec=spec, requests=requests, verbose=verbose, **kwargs
)
return cls(nla_tools=http_operation_tools)
@classmethod
def from_llm_and_url(
cls,
llm: BaseLanguageModel,
open_api_url: str,
requests: Optional[Requests] = None,
verbose: bool = False,
**kwargs: Any,
) -> NLAToolkit:
"""Instantiate the toolkit from an OpenAPI Spec URL.
Args:
llm: The language model to use.
open_api_url: The URL of the OpenAPI spec.
requests: Optional requests object. Default is None.
verbose: Whether to print verbose output. Default is False.
kwargs: Additional arguments.
Returns:
The toolkit.
"""
spec = OpenAPISpec.from_url(open_api_url)
return cls.from_llm_and_spec(
llm=llm, spec=spec, requests=requests, verbose=verbose, **kwargs
)
@classmethod
def from_llm_and_ai_plugin(
cls,
llm: BaseLanguageModel,
ai_plugin: AIPlugin,
requests: Optional[Requests] = None,
verbose: bool = False,
**kwargs: Any,
) -> NLAToolkit:
"""Instantiate the toolkit from an OpenAPI Spec URL"""
spec = OpenAPISpec.from_url(ai_plugin.api.url)
# TODO: Merge optional Auth information with the `requests` argument
return cls.from_llm_and_spec(
llm=llm,
spec=spec,
requests=requests,
verbose=verbose,
**kwargs,
)
@classmethod
def from_llm_and_ai_plugin_url(
cls,
llm: BaseLanguageModel,
ai_plugin_url: str,
requests: Optional[Requests] = None,
verbose: bool = False,
**kwargs: Any,
) -> NLAToolkit:
"""Instantiate the toolkit from an OpenAPI Spec URL"""
plugin = AIPlugin.from_url(ai_plugin_url)
return cls.from_llm_and_ai_plugin(
llm=llm, ai_plugin=plugin, requests=requests, verbose=verbose, **kwargs
)
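# A minimal usage sketch (illustrative): the spec URL is a placeholder for any
# OpenAPI document reachable over HTTP.
from langchain_openai import OpenAI

llm = OpenAI(temperature=0)
toolkit = NLAToolkit.from_llm_and_url(llm, "https://example.com/openapi.yaml")
tools = toolkit.get_tools()  # one NLATool per documented endpoint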

View File

@@ -1 +0,0 @@
"""Office365 toolkit."""

View File

@@ -1,55 +0,0 @@
from __future__ import annotations
from typing import TYPE_CHECKING, List
from langchain_core.tools import BaseTool
from langchain_core.tools.base import BaseToolkit
from pydantic import ConfigDict, Field
from langchain_community.tools.office365.create_draft_message import (
O365CreateDraftMessage,
)
from langchain_community.tools.office365.events_search import O365SearchEvents
from langchain_community.tools.office365.messages_search import O365SearchEmails
from langchain_community.tools.office365.send_event import O365SendEvent
from langchain_community.tools.office365.send_message import O365SendMessage
from langchain_community.tools.office365.utils import authenticate
if TYPE_CHECKING:
from O365 import Account
class O365Toolkit(BaseToolkit):
"""Toolkit for interacting with Office 365.
*Security Note*: This toolkit contains tools that can read and modify
the state of a service; e.g., by reading, creating, updating, or
deleting data associated with this service.
For example, this toolkit can be used to search through emails and events,
send messages and event invites, and create draft messages.
Please make sure that the permissions given by this toolkit
are appropriate for your use case.
See https://python.langchain.com/docs/security for more information.
Parameters:
account: Optional. The Office 365 account. Defaults to an account created via authenticate().
"""
account: Account = Field(default_factory=authenticate)
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
return [
O365SearchEvents(),
O365CreateDraftMessage(),
O365SearchEmails(),
O365SendEvent(),
O365SendMessage(),
]
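# A minimal usage sketch (illustrative): `authenticate` expects CLIENT_ID and
# CLIENT_SECRET from an Azure app registration to be set in the environment.
toolkit = O365Toolkit()
tools = toolkit.get_tools()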

View File

@@ -1 +0,0 @@
"""OpenAPI spec agent."""

View File

@@ -1,106 +0,0 @@
"""OpenAPI spec agent."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from langchain_core.callbacks import BaseCallbackManager
from langchain_core.language_models import BaseLanguageModel
from langchain_community.agent_toolkits.openapi.prompt import (
OPENAPI_PREFIX,
OPENAPI_SUFFIX,
)
from langchain_community.agent_toolkits.openapi.toolkit import OpenAPIToolkit
if TYPE_CHECKING:
from langchain.agents.agent import AgentExecutor
def create_openapi_agent(
llm: BaseLanguageModel,
toolkit: OpenAPIToolkit,
callback_manager: Optional[BaseCallbackManager] = None,
prefix: str = OPENAPI_PREFIX,
suffix: str = OPENAPI_SUFFIX,
format_instructions: Optional[str] = None,
input_variables: Optional[List[str]] = None,
max_iterations: Optional[int] = 15,
max_execution_time: Optional[float] = None,
early_stopping_method: str = "force",
verbose: bool = False,
return_intermediate_steps: bool = False,
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> AgentExecutor:
"""Construct an OpenAPI agent from an LLM and tools.
*Security Note*: When creating an OpenAPI agent, check the permissions
and capabilities of the underlying toolkit.
For example, the default implementation of OpenAPIToolkit
uses the RequestsToolkit, which contains tools to make arbitrary
network requests against any URL (e.g., GET, POST, PATCH, PUT, DELETE).
Control access to who can use this toolkit and
what network access it has.
See https://python.langchain.com/docs/security for more information.
Args:
llm: The language model to use.
toolkit: The OpenAPI toolkit.
callback_manager: Optional. The callback manager. Default is None.
prefix: Optional. The prefix for the prompt. Default is OPENAPI_PREFIX.
suffix: Optional. The suffix for the prompt. Default is OPENAPI_SUFFIX.
format_instructions: Optional. The format instructions for the prompt.
Default is None.
input_variables: Optional. The input variables for the prompt. Default is None.
max_iterations: Optional. The maximum number of iterations. Default is 15.
max_execution_time: Optional. The maximum execution time. Default is None.
early_stopping_method: Optional. The early stopping method. Default is "force".
verbose: Optional. Whether to print verbose output. Default is False.
return_intermediate_steps: Optional. Whether to return intermediate steps.
Default is False.
agent_executor_kwargs: Optional. Additional keyword arguments
for the agent executor.
kwargs: Additional arguments.
Returns:
The agent executor.
"""
from langchain.agents.agent import AgentExecutor
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.chains.llm import LLMChain
tools = toolkit.get_tools()
prompt_params = (
{"format_instructions": format_instructions}
if format_instructions is not None
else {}
)
prompt = ZeroShotAgent.create_prompt(
tools,
prefix=prefix,
suffix=suffix,
input_variables=input_variables,
**prompt_params,
)
llm_chain = LLMChain(
llm=llm,
prompt=prompt,
callback_manager=callback_manager,
)
tool_names = [tool.name for tool in tools]
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
return AgentExecutor.from_agent_and_tools(
agent=agent,
tools=tools,
callback_manager=callback_manager,
verbose=verbose,
return_intermediate_steps=return_intermediate_steps,
max_iterations=max_iterations,
max_execution_time=max_execution_time,
early_stopping_method=early_stopping_method,
**(agent_executor_kwargs or {}),
)
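# Illustrative usage sketch (assumes `toolkit` is an already-constructed
# OpenAPIToolkit, e.g. via OpenAPIToolkit.from_llm shown later in this commit,
# and that an OpenAI API key is configured):
from langchain_openai import OpenAI

llm = OpenAI(temperature=0)
agent_executor = create_openapi_agent(llm, toolkit, verbose=True)  # toolkit: assumed
agent_executor.run("What are the required query parameters for GET /posts?")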

View File

@@ -1,464 +0,0 @@
"""Agent that interacts with OpenAPI APIs via a hierarchical planning approach."""
import json
import re
from functools import partial
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, cast
import yaml
from langchain_core.callbacks import BaseCallbackManager
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
from langchain_core.tools import BaseTool, Tool
from pydantic import Field
from langchain_community.agent_toolkits.openapi.planner_prompt import (
API_CONTROLLER_PROMPT,
API_CONTROLLER_TOOL_DESCRIPTION,
API_CONTROLLER_TOOL_NAME,
API_ORCHESTRATOR_PROMPT,
API_PLANNER_PROMPT,
API_PLANNER_TOOL_DESCRIPTION,
API_PLANNER_TOOL_NAME,
PARSING_DELETE_PROMPT,
PARSING_GET_PROMPT,
PARSING_PATCH_PROMPT,
PARSING_POST_PROMPT,
PARSING_PUT_PROMPT,
REQUESTS_DELETE_TOOL_DESCRIPTION,
REQUESTS_GET_TOOL_DESCRIPTION,
REQUESTS_PATCH_TOOL_DESCRIPTION,
REQUESTS_POST_TOOL_DESCRIPTION,
REQUESTS_PUT_TOOL_DESCRIPTION,
)
from langchain_community.agent_toolkits.openapi.spec import ReducedOpenAPISpec
from langchain_community.llms import OpenAI
from langchain_community.tools.requests.tool import BaseRequestsTool
from langchain_community.utilities.requests import RequestsWrapper
#
# Requests tools with LLM-instructed extraction of truncated responses.
#
# Of course, truncating so bluntly may lose a lot of valuable
# information in the response.
# However, the goal for now is to have only a single inference step.
MAX_RESPONSE_LENGTH = 5000
"""Maximum length of the response to be returned."""
Operation = Literal["GET", "POST", "PUT", "DELETE", "PATCH"]
def _get_default_llm_chain(prompt: BasePromptTemplate) -> Any:
from langchain.chains.llm import LLMChain
return LLMChain(
llm=OpenAI(),
prompt=prompt,
)
def _get_default_llm_chain_factory(
prompt: BasePromptTemplate,
) -> Callable[[], Any]:
"""Returns a default LLMChain factory."""
return partial(_get_default_llm_chain, prompt)
class RequestsGetToolWithParsing(BaseRequestsTool, BaseTool):
"""Requests GET tool with LLM-instructed extraction of truncated responses."""
name: str = "requests_get"
"""Tool name."""
description: str = REQUESTS_GET_TOOL_DESCRIPTION
"""Tool description."""
response_length: int = MAX_RESPONSE_LENGTH
"""Maximum length of the response to be returned."""
llm_chain: Any = Field(
default_factory=_get_default_llm_chain_factory(PARSING_GET_PROMPT)
)
"""LLMChain used to extract the response."""
def _run(self, text: str) -> str:
from langchain.output_parsers.json import parse_json_markdown
try:
data = parse_json_markdown(text)
except json.JSONDecodeError as e:
raise e
data_params = data.get("params")
response: str = cast(
str, self.requests_wrapper.get(data["url"], params=data_params)
)
response = response[: self.response_length]
return self.llm_chain.predict(
response=response, instructions=data["output_instructions"]
).strip()
async def _arun(self, text: str) -> str:
raise NotImplementedError()
class RequestsPostToolWithParsing(BaseRequestsTool, BaseTool):
"""Requests POST tool with LLM-instructed extraction of truncated responses."""
name: str = "requests_post"
"""Tool name."""
description: str = REQUESTS_POST_TOOL_DESCRIPTION
"""Tool description."""
response_length: int = MAX_RESPONSE_LENGTH
"""Maximum length of the response to be returned."""
llm_chain: Any = Field(
default_factory=_get_default_llm_chain_factory(PARSING_POST_PROMPT)
)
"""LLMChain used to extract the response."""
def _run(self, text: str) -> str:
from langchain.output_parsers.json import parse_json_markdown
try:
data = parse_json_markdown(text)
except json.JSONDecodeError as e:
raise e
response: str = cast(str, self.requests_wrapper.post(data["url"], data["data"]))
response = response[: self.response_length]
return self.llm_chain.predict(
response=response, instructions=data["output_instructions"]
).strip()
async def _arun(self, text: str) -> str:
raise NotImplementedError()
class RequestsPatchToolWithParsing(BaseRequestsTool, BaseTool):
"""Requests PATCH tool with LLM-instructed extraction of truncated responses."""
name: str = "requests_patch"
"""Tool name."""
description: str = REQUESTS_PATCH_TOOL_DESCRIPTION
"""Tool description."""
response_length: int = MAX_RESPONSE_LENGTH
"""Maximum length of the response to be returned."""
llm_chain: Any = Field(
default_factory=_get_default_llm_chain_factory(PARSING_PATCH_PROMPT)
)
"""LLMChain used to extract the response."""
def _run(self, text: str) -> str:
from langchain.output_parsers.json import parse_json_markdown
try:
data = parse_json_markdown(text)
except json.JSONDecodeError as e:
raise e
response: str = cast(
str, self.requests_wrapper.patch(data["url"], data["data"])
)
response = response[: self.response_length]
return self.llm_chain.predict(
response=response, instructions=data["output_instructions"]
).strip()
async def _arun(self, text: str) -> str:
raise NotImplementedError()
class RequestsPutToolWithParsing(BaseRequestsTool, BaseTool):
"""Requests PUT tool with LLM-instructed extraction of truncated responses."""
name: str = "requests_put"
"""Tool name."""
description: str = REQUESTS_PUT_TOOL_DESCRIPTION
"""Tool description."""
response_length: int = MAX_RESPONSE_LENGTH
"""Maximum length of the response to be returned."""
llm_chain: Any = Field(
default_factory=_get_default_llm_chain_factory(PARSING_PUT_PROMPT)
)
"""LLMChain used to extract the response."""
def _run(self, text: str) -> str:
from langchain.output_parsers.json import parse_json_markdown
try:
data = parse_json_markdown(text)
except json.JSONDecodeError as e:
raise e
response: str = cast(str, self.requests_wrapper.put(data["url"], data["data"]))
response = response[: self.response_length]
return self.llm_chain.predict(
response=response, instructions=data["output_instructions"]
).strip()
async def _arun(self, text: str) -> str:
raise NotImplementedError()
class RequestsDeleteToolWithParsing(BaseRequestsTool, BaseTool):
"""Tool that sends a DELETE request and parses the response."""
name: str = "requests_delete"
"""The name of the tool."""
description: str = REQUESTS_DELETE_TOOL_DESCRIPTION
"""The description of the tool."""
response_length: Optional[int] = MAX_RESPONSE_LENGTH
"""The maximum length of the response."""
llm_chain: Any = Field(
default_factory=_get_default_llm_chain_factory(PARSING_DELETE_PROMPT)
)
"""The LLM chain used to parse the response."""
def _run(self, text: str) -> str:
from langchain.output_parsers.json import parse_json_markdown
try:
data = parse_json_markdown(text)
except json.JSONDecodeError as e:
raise e
response: str = cast(str, self.requests_wrapper.delete(data["url"]))
response = response[: self.response_length]
return self.llm_chain.predict(
response=response, instructions=data["output_instructions"]
).strip()
async def _arun(self, text: str) -> str:
raise NotImplementedError()
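# Illustrative sketch of the input contract these parsing tools share (assumes
# an OpenAI API key; the JSONPlaceholder URL is a public demo endpoint):
from langchain.chains.llm import LLMChain
from langchain_openai import OpenAI

demo_tool = RequestsGetToolWithParsing(
    requests_wrapper=RequestsWrapper(headers={}),
    llm_chain=LLMChain(llm=OpenAI(temperature=0), prompt=PARSING_GET_PROMPT),
    allow_dangerous_requests=True,  # explicit opt-in, as with all requests tools
)
print(
    demo_tool.run(
        json.dumps(
            {
                "url": "https://jsonplaceholder.typicode.com/posts",
                "params": {"_limit": "2"},
                "output_instructions": "Extract the titles of the posts.",
            }
        )
    )
)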
#
# Orchestrator, planner, controller.
#
def _create_api_planner_tool(
api_spec: ReducedOpenAPISpec, llm: BaseLanguageModel
) -> Tool:
from langchain.chains.llm import LLMChain
endpoint_descriptions = [
f"{name} {description}" for name, description, _ in api_spec.endpoints
]
prompt = PromptTemplate(
template=API_PLANNER_PROMPT,
input_variables=["query"],
partial_variables={"endpoints": "- " + "- ".join(endpoint_descriptions)},
)
chain = LLMChain(llm=llm, prompt=prompt)
tool = Tool(
name=API_PLANNER_TOOL_NAME,
description=API_PLANNER_TOOL_DESCRIPTION,
func=chain.run,
)
return tool
def _create_api_controller_agent(
api_url: str,
api_docs: str,
requests_wrapper: RequestsWrapper,
llm: BaseLanguageModel,
allow_dangerous_requests: bool,
allowed_operations: Sequence[Operation],
) -> Any:
from langchain.agents.agent import AgentExecutor
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.chains.llm import LLMChain
tools: List[BaseTool] = []
if "GET" in allowed_operations:
get_llm_chain = LLMChain(llm=llm, prompt=PARSING_GET_PROMPT)
tools.append(
RequestsGetToolWithParsing(
requests_wrapper=requests_wrapper,
llm_chain=get_llm_chain,
allow_dangerous_requests=allow_dangerous_requests,
)
)
if "POST" in allowed_operations:
post_llm_chain = LLMChain(llm=llm, prompt=PARSING_POST_PROMPT)
tools.append(
RequestsPostToolWithParsing(
requests_wrapper=requests_wrapper,
llm_chain=post_llm_chain,
allow_dangerous_requests=allow_dangerous_requests,
)
)
if "PUT" in allowed_operations:
put_llm_chain = LLMChain(llm=llm, prompt=PARSING_PUT_PROMPT)
tools.append(
RequestsPutToolWithParsing(
requests_wrapper=requests_wrapper,
llm_chain=put_llm_chain,
allow_dangerous_requests=allow_dangerous_requests,
)
)
if "DELETE" in allowed_operations:
delete_llm_chain = LLMChain(llm=llm, prompt=PARSING_DELETE_PROMPT)
tools.append(
RequestsDeleteToolWithParsing(
requests_wrapper=requests_wrapper,
llm_chain=delete_llm_chain,
allow_dangerous_requests=allow_dangerous_requests,
)
)
if "PATCH" in allowed_operations:
patch_llm_chain = LLMChain(llm=llm, prompt=PARSING_PATCH_PROMPT)
tools.append(
RequestsPatchToolWithParsing(
requests_wrapper=requests_wrapper,
llm_chain=patch_llm_chain,
allow_dangerous_requests=allow_dangerous_requests,
)
)
if not tools:
raise ValueError("Tools not found")
prompt = PromptTemplate(
template=API_CONTROLLER_PROMPT,
input_variables=["input", "agent_scratchpad"],
partial_variables={
"api_url": api_url,
"api_docs": api_docs,
"tool_names": ", ".join([tool.name for tool in tools]),
"tool_descriptions": "\n".join(
[f"{tool.name}: {tool.description}" for tool in tools]
),
},
)
agent = ZeroShotAgent(
llm_chain=LLMChain(llm=llm, prompt=prompt),
allowed_tools=[tool.name for tool in tools],
)
return AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True)
def _create_api_controller_tool(
api_spec: ReducedOpenAPISpec,
requests_wrapper: RequestsWrapper,
llm: BaseLanguageModel,
allow_dangerous_requests: bool,
allowed_operations: Sequence[Operation],
) -> Tool:
"""Expose controller as a tool.
The tool is invoked with a plan from the planner, and dynamically
creates a controller agent with relevant documentation only to
constrain the context.
"""
base_url = api_spec.servers[0]["url"] # TODO: do better.
def _create_and_run_api_controller_agent(plan_str: str) -> str:
pattern = r"\b(GET|POST|PATCH|DELETE|PUT)\s+(/\S+)*"
matches = re.findall(pattern, plan_str)
endpoint_names = [
"{method} {route}".format(method=method, route=route.split("?")[0])
for method, route in matches
]
docs_str = ""
for endpoint_name in endpoint_names:
found_match = False
for name, _, docs in api_spec.endpoints:
regex_name = re.compile(re.sub("\\{.*?\\}", ".*", name))
if regex_name.match(endpoint_name):
found_match = True
docs_str += f"== Docs for {endpoint_name} == \n{yaml.dump(docs)}\n"
if not found_match:
raise ValueError(f"{endpoint_name} endpoint does not exist.")
agent = _create_api_controller_agent(
base_url,
docs_str,
requests_wrapper,
llm,
allow_dangerous_requests,
allowed_operations,
)
return agent.run(plan_str)
return Tool(
name=API_CONTROLLER_TOOL_NAME,
func=_create_and_run_api_controller_agent,
description=API_CONTROLLER_TOOL_DESCRIPTION,
)
def create_openapi_agent(
api_spec: ReducedOpenAPISpec,
requests_wrapper: RequestsWrapper,
llm: BaseLanguageModel,
shared_memory: Optional[Any] = None,
callback_manager: Optional[BaseCallbackManager] = None,
verbose: bool = True,
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
allow_dangerous_requests: bool = False,
allowed_operations: Sequence[Operation] = ("GET", "POST"),
**kwargs: Any,
) -> Any:
"""Construct an OpenAI API planner and controller for a given spec.
Inject credentials via requests_wrapper.
We use a top-level "orchestrator" agent to invoke the planner and controller,
rather than a top-level planner
that invokes a controller with its plan. This is to keep the planner simple.
You need to set allow_dangerous_requests to True to use this agent with tools based on BaseRequestsTool.
Requests can be dangerous and can lead to security vulnerabilities.
For example, users can ask a server to make a request to an internal
server. It's recommended to use requests through a proxy server
and avoid accepting inputs from untrusted sources without proper sandboxing.
Please see: https://python.langchain.com/docs/security
for further security information.
Args:
api_spec: The OpenAPI spec.
requests_wrapper: The requests wrapper.
llm: The language model.
shared_memory: Optional. The shared memory. Default is None.
callback_manager: Optional. The callback manager. Default is None.
verbose: Optional. Whether to print verbose output. Default is True.
agent_executor_kwargs: Optional. Additional keyword arguments
for the agent executor.
allow_dangerous_requests: Optional. Whether to allow dangerous requests.
Default is False.
allowed_operations: Optional. The allowed operations.
Default is ("GET", "POST").
kwargs: Additional arguments.
Returns:
The agent executor.
"""
from langchain.agents.agent import AgentExecutor
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.chains.llm import LLMChain
tools = [
_create_api_planner_tool(api_spec, llm),
_create_api_controller_tool(
api_spec,
requests_wrapper,
llm,
allow_dangerous_requests,
allowed_operations,
),
]
prompt = PromptTemplate(
template=API_ORCHESTRATOR_PROMPT,
input_variables=["input", "agent_scratchpad"],
partial_variables={
"tool_names": ", ".join([tool.name for tool in tools]),
"tool_descriptions": "\n".join(
[f"{tool.name}: {tool.description}" for tool in tools]
),
},
)
agent = ZeroShotAgent(
llm_chain=LLMChain(llm=llm, prompt=prompt, memory=shared_memory),
allowed_tools=[tool.name for tool in tools],
**kwargs,
)
return AgentExecutor.from_agent_and_tools(
agent=agent,
tools=tools,
callback_manager=callback_manager,
verbose=verbose,
**(agent_executor_kwargs or {}),
)
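# End-to-end usage sketch (assumes a local spec file and an OpenAI API key;
# the file name and model name are illustrative):
from langchain_community.agent_toolkits.openapi.spec import reduce_openapi_spec
from langchain_openai import ChatOpenAI

with open("openapi.yml") as f:  # hypothetical spec file
    raw_spec = yaml.safe_load(f)

planner_agent = create_openapi_agent(
    reduce_openapi_spec(raw_spec),
    RequestsWrapper(headers={}),
    ChatOpenAI(model="gpt-4o-mini", temperature=0),
    allow_dangerous_requests=True,  # required opt-in, per the docstring above
)
planner_agent.run("Fetch the top two posts and summarize their titles.")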

View File

@@ -1,235 +0,0 @@
# flake8: noqa
from langchain_core.prompts.prompt import PromptTemplate
API_PLANNER_PROMPT = """You are a planner that plans a sequence of API calls to assist with user queries against an API.
You should:
1) evaluate whether the user query can be solved by the API documented below. If no, say why.
2) if yes, generate a plan of API calls and say what they are doing step by step.
3) If the plan includes a DELETE call, you should always ask the User for authorization first unless the User has specifically asked to delete something.
You should only use API endpoints documented below ("Endpoints you can use:").
You can only use the DELETE tool if the User has specifically asked to delete something. Otherwise, you should return a request for authorization from the User first.
Some user queries can be resolved in a single API call, but some will require several API calls.
The plan will be passed to an API controller that can format it into web requests and return the responses.
----
Here are some examples:
Fake endpoints for examples:
GET /user to get information about the current user
GET /products/search search across products
POST /users/{{id}}/cart to add products to a user's cart
PATCH /users/{{id}}/cart to update a user's cart
PUT /users/{{id}}/coupon to apply idempotent coupon to a user's cart
DELETE /users/{{id}}/cart to delete a user's cart
User query: tell me a joke
Plan: Sorry, this API's domain is shopping, not comedy.
User query: I want to buy a couch
Plan: 1. GET /products with a query param to search for couches
2. GET /user to find the user's id
3. POST /users/{{id}}/cart to add a couch to the user's cart
User query: I want to add a lamp to my cart
Plan: 1. GET /products with a query param to search for lamps
2. GET /user to find the user's id
3. PATCH /users/{{id}}/cart to add a lamp to the user's cart
User query: I want to add a coupon to my cart
Plan: 1. GET /user to find the user's id
2. PUT /users/{{id}}/coupon to apply the coupon
User query: I want to delete my cart
Plan: 1. GET /user to find the user's id
2. DELETE required. Did user specify DELETE or previously authorize? Yes, proceed.
3. DELETE /users/{{id}}/cart to delete the user's cart
User query: I want to start a new cart
Plan: 1. GET /user to find the user's id
2. DELETE required. Did user specify DELETE or previously authorize? No, ask for authorization.
3. Are you sure you want to delete your cart?
----
Here are endpoints you can use. Do not reference any of the endpoints above.
{endpoints}
----
User query: {query}
Plan:"""
API_PLANNER_TOOL_NAME = "api_planner"
API_PLANNER_TOOL_DESCRIPTION = f"Can be used to generate the right API calls to assist with a user query, like {API_PLANNER_TOOL_NAME}(query). Should always be called before trying to call the API controller."
# Execution.
API_CONTROLLER_PROMPT = """You are an agent that gets a sequence of API calls and given their documentation, should execute them and return the final response.
If you cannot complete them and run into issues, you should explain the issue. If you're unable to resolve an API call, you can retry the API call. When interacting with API objects, you should extract ids for inputs to other API calls but ids and names for outputs returned to the User.
Here is documentation on the API:
Base url: {api_url}
Endpoints:
{api_docs}
Here are tools to execute requests against the API: {tool_descriptions}
Starting below, you should follow this format:
Plan: the plan of API calls to execute
Thought: you should always think about what to do
Action: the action to take, should be one of the tools [{tool_names}]
Action Input: the input to the action
Observation: the output of the action
... (this Thought/Action/Action Input/Observation can repeat N times)
Thought: I am finished executing the plan (or, I cannot finish executing the plan without knowing some other information.)
Final Answer: the final output from executing the plan or missing information I'd need to re-plan correctly.
Begin!
Plan: {input}
Thought:
{agent_scratchpad}
"""
API_CONTROLLER_TOOL_NAME = "api_controller"
API_CONTROLLER_TOOL_DESCRIPTION = f"Can be used to execute a plan of API calls, like {API_CONTROLLER_TOOL_NAME}(plan)."
# Orchestrate planning + execution.
# The goal is to have an agent at the top-level (e.g. so it can recover from errors and re-plan) while
# keeping planning (and specifically the planning prompt) simple.
API_ORCHESTRATOR_PROMPT = """You are an agent that assists with user queries against API, things like querying information or creating resources.
Some user queries can be resolved in a single API call, particularly if you can find appropriate params from the OpenAPI spec; though some require several API calls.
You should always plan your API calls first, and then execute the plan second.
If the plan includes a DELETE call, be sure to ask the User for authorization first unless the User has specifically asked to delete something.
You should never return information without executing the api_controller tool.
Here are the tools to plan and execute API requests: {tool_descriptions}
Starting below, you should follow this format:
User query: the query a User wants help with related to the API
Thought: you should always think about what to do
Action: the action to take, should be one of the tools [{tool_names}]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can repeat N times)
Thought: I am finished executing a plan and have the information the user asked for or the data the user asked to create
Final Answer: the final output from executing the plan
Example:
User query: can you add some trendy stuff to my shopping cart.
Thought: I should plan API calls first.
Action: api_planner
Action Input: I need to find the right API calls to add trendy items to the user's shopping cart
Observation: 1) GET /items with params 'trending' is 'True' to get trending item ids
2) GET /user to get user
3) POST /cart to post the trending items to the user's cart
Thought: I'm ready to execute the API calls.
Action: api_controller
Action Input: 1) GET /items params 'trending' is 'True' to get trending item ids
2) GET /user to get user
3) POST /cart to post the trending items to the user's cart
...
Begin!
User query: {input}
Thought: I should generate a plan to help with this query and then copy that plan exactly to the controller.
{agent_scratchpad}"""
REQUESTS_GET_TOOL_DESCRIPTION = """Use this to GET content from a website.
Input to the tool should be a json string with 3 keys: "url", "params" and "output_instructions".
The value of "url" should be a string.
The value of "params" should be a dict of the needed and available parameters from the OpenAPI spec related to the endpoint.
If parameters are not needed, or not available, leave it empty.
The value of "output_instructions" should be instructions on what information to extract from the response,
for example the id(s) for a resource(s) that the GET request fetches.
"""
PARSING_GET_PROMPT = PromptTemplate(
template="""Here is an API response:\n\n{response}\n\n====
Your task is to extract some information according to these instructions: {instructions}
When working with API objects, you should usually use ids over names.
If the response indicates an error, you should instead output a summary of the error.
Output:""",
input_variables=["response", "instructions"],
)
REQUESTS_POST_TOOL_DESCRIPTION = """Use this when you want to POST to a website.
Input to the tool should be a json string with 3 keys: "url", "data", and "output_instructions".
The value of "url" should be a string.
The value of "data" should be a dictionary of key-value pairs you want to POST to the url.
The value of "output_instructions" should be instructions on what information to extract from the response, for example the id(s) for a resource(s) that the POST request creates.
Always use double quotes for strings in the json string."""
PARSING_POST_PROMPT = PromptTemplate(
template="""Here is an API response:\n\n{response}\n\n====
Your task is to extract some information according to these instructions: {instructions}
When working with API objects, you should usually use ids over names. Do not return any ids or names that are not in the response.
If the response indicates an error, you should instead output a summary of the error.
Output:""",
input_variables=["response", "instructions"],
)
REQUESTS_PATCH_TOOL_DESCRIPTION = """Use this when you want to PATCH content on a website.
Input to the tool should be a json string with 3 keys: "url", "data", and "output_instructions".
The value of "url" should be a string.
The value of "data" should be a dictionary of key-value pairs of the body params available in the OpenAPI spec you want to PATCH the content with at the url.
The value of "output_instructions" should be instructions on what information to extract from the response, for example the id(s) for a resource(s) that the PATCH request creates.
Always use double quotes for strings in the json string."""
PARSING_PATCH_PROMPT = PromptTemplate(
template="""Here is an API response:\n\n{response}\n\n====
Your task is to extract some information according to these instructions: {instructions}
When working with API objects, you should usually use ids over names. Do not return any ids or names that are not in the response.
If the response indicates an error, you should instead output a summary of the error.
Output:""",
input_variables=["response", "instructions"],
)
REQUESTS_PUT_TOOL_DESCRIPTION = """Use this when you want to PUT to a website.
Input to the tool should be a json string with 3 keys: "url", "data", and "output_instructions".
The value of "url" should be a string.
The value of "data" should be a dictionary of key-value pairs you want to PUT to the url.
The value of "output_instructions" should be instructions on what information to extract from the response, for example the id(s) for a resource(s) that the PUT request creates.
Always use double quotes for strings in the json string."""
PARSING_PUT_PROMPT = PromptTemplate(
template="""Here is an API response:\n\n{response}\n\n====
Your task is to extract some information according to these instructions: {instructions}
When working with API objects, you should usually use ids over names. Do not return any ids or names that are not in the response.
If the response indicates an error, you should instead output a summary of the error.
Output:""",
input_variables=["response", "instructions"],
)
REQUESTS_DELETE_TOOL_DESCRIPTION = """ONLY USE THIS TOOL WHEN THE USER HAS SPECIFICALLY REQUESTED TO DELETE CONTENT FROM A WEBSITE.
Input to the tool should be a json string with 2 keys: "url", and "output_instructions".
The value of "url" should be a string.
The value of "output_instructions" should be instructions on what information to extract from the response, for example the id(s) for a resource(s) that the DELETE request creates.
Always use double quotes for strings in the json string.
ONLY USE THIS TOOL IF THE USER HAS SPECIFICALLY REQUESTED TO DELETE SOMETHING."""
PARSING_DELETE_PROMPT = PromptTemplate(
template="""Here is an API response:\n\n{response}\n\n====
Your task is to extract some information according to these instructions: {instructions}
When working with API objects, you should usually use ids over names. Do not return any ids or names that are not in the response.
If the response indicates an error, you should instead output a summary of the error.
Output:""",
input_variables=["response", "instructions"],
)
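# For reference, example payloads matching the tool descriptions above (the
# URLs are illustrative; each tool receives the payload as a JSON *string*):
import json

requests_get_input = json.dumps(
    {
        "url": "https://api.example.com/products/search",
        "params": {"q": "couch"},
        "output_instructions": "Extract the ids of the matching products.",
    }
)
requests_post_input = json.dumps(
    {
        "url": "https://api.example.com/users/1/cart",
        "data": {"product_id": 42},
        "output_instructions": "Extract the id of the created cart entry.",
    }
)
requests_delete_input = json.dumps(
    {  # note: DELETE takes only "url" and "output_instructions"
        "url": "https://api.example.com/users/1/cart",
        "output_instructions": "Summarize whether the deletion succeeded.",
    }
)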

View File

@@ -1,29 +0,0 @@
# flake8: noqa
OPENAPI_PREFIX = """You are an agent designed to answer questions by making web requests to an API given the openapi spec.
If the question does not seem related to the API, return I don't know. Do not make up an answer.
Only use information provided by the tools to construct your response.
First, find the base URL needed to make the request.
Second, find the relevant paths needed to answer the question. Take note that, sometimes, you might need to make more than one request to more than one path to answer the question.
Third, find the required parameters needed to make the request. For GET requests, these are usually URL parameters and for POST requests, these are request body parameters.
Fourth, make the requests needed to answer the question. Ensure that you are sending the correct parameters to the request by checking which parameters are required. For parameters with a fixed set of values, please use the spec to look at which values are allowed.
Use the exact parameter names as listed in the spec, do not make up any names or abbreviate the names of parameters.
If you get a not found error, ensure that you are using a path that actually exists in the spec.
"""
OPENAPI_SUFFIX = """Begin!
Question: {input}
Thought: I should explore the spec to find the base server url for the API in the servers node.
{agent_scratchpad}"""
DESCRIPTION = """Can be used to answer questions about the openapi spec for the API. Always use this tool before trying to make a request.
Example inputs to this tool:
'What are the required query parameters for a GET request to the /bar endpoint?'
'What are the required parameters in the request body for a POST request to the /foo endpoint?'
Always give this tool a specific question."""

View File

@@ -1,82 +0,0 @@
"""Quick and dirty representation for OpenAPI specs."""
from dataclasses import dataclass
from typing import List, Tuple
from langchain_core.utils.json_schema import dereference_refs
@dataclass(frozen=True)
class ReducedOpenAPISpec:
"""A reduced OpenAPI spec.
This is a quick and dirty representation for OpenAPI specs.
Parameters:
servers: The servers in the spec.
description: The description of the spec.
endpoints: The endpoints in the spec.
"""
servers: List[dict]
description: str
endpoints: List[Tuple[str, str, dict]]
def reduce_openapi_spec(spec: dict, dereference: bool = True) -> ReducedOpenAPISpec:
"""Simplify/distill/minify a spec somehow.
I want a smaller target for retrieval and (more importantly)
I want smaller results from retrieval.
I was hoping https://openapi.tools/ would have some useful bits
to this end, but it doesn't seem so.
Args:
spec: The OpenAPI spec.
dereference: Whether to dereference the spec. Default is True.
Returns:
ReducedOpenAPISpec: The reduced OpenAPI spec.
"""
# 1. Consider only get, post, patch, put, delete endpoints.
endpoints = [
(f"{operation_name.upper()} {route}", docs.get("description"), docs)
for route, operation in spec["paths"].items()
for operation_name, docs in operation.items()
if operation_name in ["get", "post", "patch", "put", "delete"]
]
# 2. Replace any refs so that complete docs are retrieved.
# Note: probably want to do this post-retrieval, as it blows up the size of the spec.
if dereference:
endpoints = [
(name, description, dereference_refs(docs, full_schema=spec))
for name, description, docs in endpoints
]
# 3. Strip docs down to required request args + happy path response.
def reduce_endpoint_docs(docs: dict) -> dict:
out = {}
if docs.get("description"):
out["description"] = docs.get("description")
if docs.get("parameters"):
out["parameters"] = [
parameter
for parameter in docs.get("parameters", [])
if parameter.get("required")
]
if "200" in docs["responses"]:
out["responses"] = docs["responses"]["200"]
if docs.get("requestBody"):
out["requestBody"] = docs.get("requestBody")
return out
endpoints = [
(name, description, reduce_endpoint_docs(docs))
for name, description, docs in endpoints
]
return ReducedOpenAPISpec(
servers=spec["servers"],
description=spec["info"].get("description", ""),
endpoints=endpoints,
)
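# Usage sketch (assumes a local OpenAPI 3.x file; the file name is illustrative):
import yaml

with open("openapi.yml") as f:
    raw_spec = yaml.safe_load(f)

reduced = reduce_openapi_spec(raw_spec)
print(len(reduced.endpoints))   # number of GET/POST/PATCH/PUT/DELETE operations
print(reduced.endpoints[0][0])  # e.g. "GET /posts"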

View File

@@ -1,238 +0,0 @@
"""Requests toolkit."""
from __future__ import annotations
from typing import Any, List
from langchain_core.language_models import BaseLanguageModel
from langchain_core.tools import BaseTool, Tool
from langchain_core.tools.base import BaseToolkit
from langchain_community.agent_toolkits.json.base import create_json_agent
from langchain_community.agent_toolkits.json.toolkit import JsonToolkit
from langchain_community.agent_toolkits.openapi.prompt import DESCRIPTION
from langchain_community.tools.json.tool import JsonSpec
from langchain_community.tools.requests.tool import (
RequestsDeleteTool,
RequestsGetTool,
RequestsPatchTool,
RequestsPostTool,
RequestsPutTool,
)
from langchain_community.utilities.requests import TextRequestsWrapper
class RequestsToolkit(BaseToolkit):
"""Toolkit for making REST requests.
*Security Note*: This toolkit contains tools to make GET, POST, PATCH, PUT,
and DELETE requests to an API.
Exercise care in who is allowed to use this toolkit. If exposing
to end users, consider that users will be able to make arbitrary
requests on behalf of the server hosting the code. For example,
users could ask the server to make a request to a private API
that is only accessible from the server.
Control access to who can submit requests using this toolkit and
what network access it has.
See https://python.langchain.com/docs/security for more information.
Setup:
Install ``langchain-community``.
.. code-block:: bash
pip install -U langchain-community
Key init args:
requests_wrapper: langchain_community.utilities.requests.GenericRequestsWrapper
wrapper for executing requests.
allow_dangerous_requests: bool
Defaults to False. Must "opt in" to using dangerous requests by setting to True.
Instantiate:
.. code-block:: python
from langchain_community.agent_toolkits.openapi.toolkit import RequestsToolkit
from langchain_community.utilities.requests import TextRequestsWrapper
toolkit = RequestsToolkit(
requests_wrapper=TextRequestsWrapper(headers={}),
allow_dangerous_requests=ALLOW_DANGEROUS_REQUEST,
)
Tools:
.. code-block:: python
tools = toolkit.get_tools()
tools
.. code-block:: none
[RequestsGetTool(requests_wrapper=TextRequestsWrapper(headers={}, aiosession=None, auth=None, response_content_type='text', verify=True), allow_dangerous_requests=True),
RequestsPostTool(requests_wrapper=TextRequestsWrapper(headers={}, aiosession=None, auth=None, response_content_type='text', verify=True), allow_dangerous_requests=True),
RequestsPatchTool(requests_wrapper=TextRequestsWrapper(headers={}, aiosession=None, auth=None, response_content_type='text', verify=True), allow_dangerous_requests=True),
RequestsPutTool(requests_wrapper=TextRequestsWrapper(headers={}, aiosession=None, auth=None, response_content_type='text', verify=True), allow_dangerous_requests=True),
RequestsDeleteTool(requests_wrapper=TextRequestsWrapper(headers={}, aiosession=None, auth=None, response_content_type='text', verify=True), allow_dangerous_requests=True)]
Use within an agent:
.. code-block:: python
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import create_react_agent
api_spec = \"\"\"
openapi: 3.0.0
info:
title: JSONPlaceholder API
version: 1.0.0
servers:
- url: https://jsonplaceholder.typicode.com
paths:
/posts:
get:
summary: Get posts
parameters: &id001
- name: _limit
in: query
required: false
schema:
type: integer
example: 2
description: Limit the number of results
\"\"\"
system_message = \"\"\"
You have access to an API to help answer user queries.
Here is documentation on the API:
{api_spec}
\"\"\".format(api_spec=api_spec)
llm = ChatOpenAI(model="gpt-4o-mini")
agent_executor = create_react_agent(llm, tools, state_modifier=system_message)
example_query = "Fetch the top two posts. What are their titles?"
events = agent_executor.stream(
{"messages": [("user", example_query)]},
stream_mode="values",
)
for event in events:
event["messages"][-1].pretty_print()
.. code-block:: none
================================ Human Message =================================
Fetch the top two posts. What are their titles?
================================== Ai Message ==================================
Tool Calls:
requests_get (call_RV2SOyzCnV5h2sm4WPgG8fND)
Call ID: call_RV2SOyzCnV5h2sm4WPgG8fND
Args:
url: https://jsonplaceholder.typicode.com/posts?_limit=2
================================= Tool Message =================================
Name: requests_get
[
{
"userId": 1,
"id": 1,
"title": "sunt aut facere repellat provident occaecati excepturi optio reprehenderit",
"body": "quia et suscipit..."
},
{
"userId": 1,
"id": 2,
"title": "qui est esse",
"body": "est rerum tempore vitae..."
}
]
================================== Ai Message ==================================
The titles of the top two posts are:
1. "sunt aut facere repellat provident occaecati excepturi optio reprehenderit"
2. "qui est esse"
""" # noqa: E501
requests_wrapper: TextRequestsWrapper
"""The requests wrapper."""
allow_dangerous_requests: bool = False
"""Allow dangerous requests. See documentation for details."""
def get_tools(self) -> List[BaseTool]:
"""Return a list of tools."""
return [
RequestsGetTool(
requests_wrapper=self.requests_wrapper,
allow_dangerous_requests=self.allow_dangerous_requests,
),
RequestsPostTool(
requests_wrapper=self.requests_wrapper,
allow_dangerous_requests=self.allow_dangerous_requests,
),
RequestsPatchTool(
requests_wrapper=self.requests_wrapper,
allow_dangerous_requests=self.allow_dangerous_requests,
),
RequestsPutTool(
requests_wrapper=self.requests_wrapper,
allow_dangerous_requests=self.allow_dangerous_requests,
),
RequestsDeleteTool(
requests_wrapper=self.requests_wrapper,
allow_dangerous_requests=self.allow_dangerous_requests,
),
]
class OpenAPIToolkit(BaseToolkit):
"""Toolkit for interacting with an OpenAPI API.
*Security Note*: This toolkit contains tools that can read and modify
the state of a service; e.g., by creating, deleting, updating, or
reading underlying data.
For example, this toolkit can be used to delete data exposed via
an OpenAPI compliant API.
"""
json_agent: Any
"""The JSON agent."""
requests_wrapper: TextRequestsWrapper
"""The requests wrapper."""
allow_dangerous_requests: bool = False
"""Allow dangerous requests. See documentation for details."""
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
json_agent_tool = Tool(
name="json_explorer",
func=self.json_agent.run,
description=DESCRIPTION,
)
request_toolkit = RequestsToolkit(
requests_wrapper=self.requests_wrapper,
allow_dangerous_requests=self.allow_dangerous_requests,
)
return [*request_toolkit.get_tools(), json_agent_tool]
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
json_spec: JsonSpec,
requests_wrapper: TextRequestsWrapper,
allow_dangerous_requests: bool = False,
**kwargs: Any,
) -> OpenAPIToolkit:
"""Create json agent from llm, then initialize."""
json_agent = create_json_agent(llm, JsonToolkit(spec=json_spec), **kwargs)
return cls(
json_agent=json_agent,
requests_wrapper=requests_wrapper,
allow_dangerous_requests=allow_dangerous_requests,
)
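# Construction sketch (assumes `raw_spec` is a parsed OpenAPI spec dict and an
# OpenAI API key is configured; max_value_length trims large schema values fed
# to the underlying JSON agent):
from langchain_openai import OpenAI

openapi_toolkit = OpenAPIToolkit.from_llm(
    OpenAI(temperature=0),
    JsonSpec(dict_=raw_spec, max_value_length=4000),  # raw_spec: assumed dict
    TextRequestsWrapper(headers={}),
    allow_dangerous_requests=True,  # explicit opt-in
    verbose=True,
)
openapi_tools = openapi_toolkit.get_tools()  # five requests tools + json_explorer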

View File

@@ -1,7 +0,0 @@
"""Playwright browser toolkit."""
from langchain_community.agent_toolkits.playwright.toolkit import (
PlayWrightBrowserToolkit,
)
__all__ = ["PlayWrightBrowserToolkit"]

View File

@@ -1,122 +0,0 @@
"""Playwright web browser toolkit."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, List, Optional, Type, cast
from langchain_core.tools import BaseTool, BaseToolkit
from pydantic import ConfigDict, model_validator
from langchain_community.tools.playwright.base import (
BaseBrowserTool,
lazy_import_playwright_browsers,
)
from langchain_community.tools.playwright.click import ClickTool
from langchain_community.tools.playwright.current_page import CurrentWebPageTool
from langchain_community.tools.playwright.extract_hyperlinks import (
ExtractHyperlinksTool,
)
from langchain_community.tools.playwright.extract_text import ExtractTextTool
from langchain_community.tools.playwright.get_elements import GetElementsTool
from langchain_community.tools.playwright.navigate import NavigateTool
from langchain_community.tools.playwright.navigate_back import NavigateBackTool
if TYPE_CHECKING:
from playwright.async_api import Browser as AsyncBrowser
from playwright.sync_api import Browser as SyncBrowser
else:
try:
# We do this so pydantic can resolve the types when instantiating
from playwright.async_api import Browser as AsyncBrowser
from playwright.sync_api import Browser as SyncBrowser
except ImportError:
pass
class PlayWrightBrowserToolkit(BaseToolkit):
"""Toolkit for PlayWright browser tools.
**Security Note**: This toolkit provides code to control a web-browser.
Be careful if exposing this toolkit to end users. The tools in the toolkit
are capable of navigating to arbitrary webpages, clicking on arbitrary
elements, and extracting arbitrary text and hyperlinks from webpages.
Specifically, by default this toolkit allows navigating to:
- Any URL (including any internal network URLs)
- Local files
If exposing to end-users, consider limiting network access to the
server that hosts the agent; in addition, it is advised
to create a custom NavigationTool with an args_schema that limits the URLs
that can be navigated to (e.g., only allow navigating to URLs that
start with a particular prefix).
Remember to scope permissions to the minimal permissions necessary for
the application. If the default tool selection is not appropriate for
the application, consider creating a custom toolkit with the appropriate
tools.
See https://python.langchain.com/docs/security for more information.
Parameters:
sync_browser: Optional. The sync browser. Default is None.
async_browser: Optional. The async browser. Default is None.
"""
sync_browser: Optional["SyncBrowser"] = None
async_browser: Optional["AsyncBrowser"] = None
model_config = ConfigDict(
arbitrary_types_allowed=True,
extra="forbid",
)
@model_validator(mode="before")
@classmethod
def validate_imports_and_browser_provided(cls, values: dict) -> Any:
"""Check that the arguments are valid."""
lazy_import_playwright_browsers()
if values.get("async_browser") is None and values.get("sync_browser") is None:
raise ValueError("Either async_browser or sync_browser must be specified.")
return values
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
tool_classes: List[Type[BaseBrowserTool]] = [
ClickTool,
NavigateTool,
NavigateBackTool,
ExtractTextTool,
ExtractHyperlinksTool,
GetElementsTool,
CurrentWebPageTool,
]
tools = [
tool_cls.from_browser(
sync_browser=self.sync_browser, async_browser=self.async_browser
)
for tool_cls in tool_classes
]
return cast(List[BaseTool], tools)
@classmethod
def from_browser(
cls,
sync_browser: Optional[SyncBrowser] = None,
async_browser: Optional[AsyncBrowser] = None,
) -> PlayWrightBrowserToolkit:
"""Instantiate the toolkit.
Args:
sync_browser: Optional. The sync browser. Default is None.
async_browser: Optional. The async browser. Default is None.
Returns:
The toolkit.
"""
# This is to raise a better error than the forward ref ones Pydantic would have
lazy_import_playwright_browsers()
return cls(sync_browser=sync_browser, async_browser=async_browser)
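# Usage sketch (requires the `playwright` package and installed browsers,
# e.g. `pip install playwright && playwright install`):
from langchain_community.tools.playwright.utils import (
    create_async_playwright_browser,
)

async_browser = create_async_playwright_browser()
pw_toolkit = PlayWrightBrowserToolkit.from_browser(async_browser=async_browser)
pw_tools = pw_toolkit.get_tools()  # click, navigate, extract_text, ...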

View File

@@ -1 +0,0 @@
"""Polygon Toolkit"""

View File

@@ -1,54 +0,0 @@
from typing import List
from langchain_core.tools import BaseTool
from langchain_core.tools.base import BaseToolkit
from langchain_community.tools.polygon import (
PolygonAggregates,
PolygonFinancials,
PolygonLastQuote,
PolygonTickerNews,
)
from langchain_community.utilities.polygon import PolygonAPIWrapper
class PolygonToolkit(BaseToolkit):
"""Polygon Toolkit.
Parameters:
tools: List[BaseTool]. The tools in the toolkit.
"""
tools: List[BaseTool] = []
@classmethod
def from_polygon_api_wrapper(
cls, polygon_api_wrapper: PolygonAPIWrapper
) -> "PolygonToolkit":
"""Create a Polygon Toolkit from a Polygon API Wrapper.
Args:
polygon_api_wrapper: PolygonAPIWrapper. The Polygon API Wrapper.
Returns:
PolygonToolkit. The Polygon Toolkit.
"""
tools = [
PolygonAggregates(
api_wrapper=polygon_api_wrapper,
),
PolygonLastQuote(
api_wrapper=polygon_api_wrapper,
),
PolygonTickerNews(
api_wrapper=polygon_api_wrapper,
),
PolygonFinancials(
api_wrapper=polygon_api_wrapper,
),
]
return cls(tools=tools)
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
return self.tools
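# Usage sketch (assumes a POLYGON_API_KEY environment variable is set):
from langchain_community.utilities.polygon import PolygonAPIWrapper

polygon_toolkit = PolygonToolkit.from_polygon_api_wrapper(PolygonAPIWrapper())
polygon_tools = polygon_toolkit.get_tools()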

View File

@@ -1 +0,0 @@
"""Power BI agent."""

View File

@@ -1,94 +0,0 @@
"""Power BI agent."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from langchain_core.callbacks import BaseCallbackManager
from langchain_core.language_models import BaseLanguageModel
from langchain_community.agent_toolkits.powerbi.prompt import (
POWERBI_PREFIX,
POWERBI_SUFFIX,
)
from langchain_community.agent_toolkits.powerbi.toolkit import PowerBIToolkit
from langchain_community.utilities.powerbi import PowerBIDataset
if TYPE_CHECKING:
from langchain.agents import AgentExecutor
def create_pbi_agent(
llm: BaseLanguageModel,
toolkit: Optional[PowerBIToolkit] = None,
powerbi: Optional[PowerBIDataset] = None,
callback_manager: Optional[BaseCallbackManager] = None,
prefix: str = POWERBI_PREFIX,
suffix: str = POWERBI_SUFFIX,
format_instructions: Optional[str] = None,
examples: Optional[str] = None,
input_variables: Optional[List[str]] = None,
top_k: int = 10,
verbose: bool = False,
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> AgentExecutor:
"""Construct a Power BI agent from an LLM and tools.
Args:
llm: The language model to use.
toolkit: Optional. The Power BI toolkit. Default is None.
powerbi: Optional. The Power BI dataset. Default is None.
callback_manager: Optional. The callback manager. Default is None.
prefix: Optional. The prefix for the prompt. Default is POWERBI_PREFIX.
suffix: Optional. The suffix for the prompt. Default is POWERBI_SUFFIX.
format_instructions: Optional. The format instructions for the prompt.
Default is None.
examples: Optional. The examples for the prompt. Default is None.
input_variables: Optional. The input variables for the prompt. Default is None.
top_k: Optional. The top k for the prompt. Default is 10.
verbose: Optional. Whether to print verbose output. Default is False.
agent_executor_kwargs: Optional. The agent executor kwargs. Default is None.
kwargs: Any. Additional keyword arguments.
Returns:
The agent executor.
"""
from langchain.agents import AgentExecutor
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.chains.llm import LLMChain
if toolkit is None:
if powerbi is None:
raise ValueError("Must provide either a toolkit or powerbi dataset")
toolkit = PowerBIToolkit(powerbi=powerbi, llm=llm, examples=examples)
tools = toolkit.get_tools()
tables = powerbi.table_names if powerbi else toolkit.powerbi.table_names
prompt_params = (
{"format_instructions": format_instructions}
if format_instructions is not None
else {}
)
agent = ZeroShotAgent(
llm_chain=LLMChain(
llm=llm,
prompt=ZeroShotAgent.create_prompt(
tools,
prefix=prefix.format(top_k=top_k).format(tables=tables),
suffix=suffix,
input_variables=input_variables,
**prompt_params,
),
callback_manager=callback_manager,
verbose=verbose,
),
allowed_tools=[tool.name for tool in tools],
**kwargs,
)
return AgentExecutor.from_agent_and_tools(
agent=agent,
tools=tools,
callback_manager=callback_manager,
verbose=verbose,
**(agent_executor_kwargs or {}),
)
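# Usage sketch (assumes Azure AD credentials with access to the dataset; the
# dataset id, table name, and model name are illustrative):
from azure.identity import DefaultAzureCredential
from langchain_openai import ChatOpenAI

pbi_agent = create_pbi_agent(
    llm=ChatOpenAI(model="gpt-4o-mini", temperature=0),
    powerbi=PowerBIDataset(
        dataset_id="<dataset-id>",
        table_names=["Sales"],
        credential=DefaultAzureCredential(),
    ),
    verbose=True,
)
pbi_agent.run("How many rows are in the Sales table?")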

View File

@@ -1,91 +0,0 @@
"""Power BI agent."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from langchain_core.callbacks import BaseCallbackManager
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_community.agent_toolkits.powerbi.prompt import (
POWERBI_CHAT_PREFIX,
POWERBI_CHAT_SUFFIX,
)
from langchain_community.agent_toolkits.powerbi.toolkit import PowerBIToolkit
from langchain_community.utilities.powerbi import PowerBIDataset
if TYPE_CHECKING:
from langchain.agents import AgentExecutor
from langchain.agents.agent import AgentOutputParser
from langchain.memory.chat_memory import BaseChatMemory
def create_pbi_chat_agent(
llm: BaseChatModel,
toolkit: Optional[PowerBIToolkit] = None,
powerbi: Optional[PowerBIDataset] = None,
callback_manager: Optional[BaseCallbackManager] = None,
output_parser: Optional[AgentOutputParser] = None,
prefix: str = POWERBI_CHAT_PREFIX,
suffix: str = POWERBI_CHAT_SUFFIX,
examples: Optional[str] = None,
input_variables: Optional[List[str]] = None,
memory: Optional[BaseChatMemory] = None,
top_k: int = 10,
verbose: bool = False,
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> AgentExecutor:
"""Construct a Power BI agent from a Chat LLM and tools.
If you supply only a Power BI dataset and no toolkit, the same LLM is used both to build the toolkit and to run the agent.
Args:
llm: The language model to use.
toolkit: Optional. The Power BI toolkit. Default is None.
powerbi: Optional. The Power BI dataset. Default is None.
callback_manager: Optional. The callback manager. Default is None.
output_parser: Optional. The output parser. Default is None.
prefix: Optional. The prefix for the prompt. Default is POWERBI_CHAT_PREFIX.
suffix: Optional. The suffix for the prompt. Default is POWERBI_CHAT_SUFFIX.
examples: Optional. The examples for the prompt. Default is None.
input_variables: Optional. The input variables for the prompt. Default is None.
memory: Optional. The memory. Default is None.
top_k: Optional. The top k for the prompt. Default is 10.
verbose: Optional. Whether to print verbose output. Default is False.
agent_executor_kwargs: Optional. The agent executor kwargs. Default is None.
kwargs: Any. Additional keyword arguments.
Returns:
The agent executor.
"""
from langchain.agents import AgentExecutor
from langchain.agents.conversational_chat.base import ConversationalChatAgent
from langchain.memory import ConversationBufferMemory
if toolkit is None:
if powerbi is None:
raise ValueError("Must provide either a toolkit or powerbi dataset")
toolkit = PowerBIToolkit(powerbi=powerbi, llm=llm, examples=examples)
tools = toolkit.get_tools()
tables = powerbi.table_names if powerbi else toolkit.powerbi.table_names
agent = ConversationalChatAgent.from_llm_and_tools(
llm=llm,
tools=tools,
system_message=prefix.format(top_k=top_k).format(tables=tables),
human_message=suffix,
input_variables=input_variables,
callback_manager=callback_manager,
output_parser=output_parser,
verbose=verbose,
**kwargs,
)
return AgentExecutor.from_agent_and_tools(
agent=agent,
tools=tools,
callback_manager=callback_manager,
memory=memory
or ConversationBufferMemory(memory_key="chat_history", return_messages=True),
verbose=verbose,
**(agent_executor_kwargs or {}),
)

View File

@@ -1,37 +0,0 @@
# flake8: noqa
"""Prompts for PowerBI agent."""
POWERBI_PREFIX = """You are an agent designed to help users interact with a PowerBI Dataset.
Agent has access to a tool that can write a query based on the question and then run it against PowerBI, Microsoft's business intelligence tool. The questions from the users should be interpreted as related to the dataset that is available and not general questions about the world. If the question does not seem related to the dataset, return "This does not appear to be part of this dataset." as the answer.
Given an input question, ask to run the question against the dataset, then look at the results and return the answer, the answer should be a complete sentence that answers the question, if multiple rows are asked find a way to write that in an easily readable format for a human, also make sure to represent numbers in readable ways, like 1M instead of 1000000. Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.
"""
POWERBI_SUFFIX = """Begin!
Question: {input}
Thought: I can first ask which tables I have, then how each table is defined and then ask the query tool the question I need, and finally create a nice sentence that answers the question.
{agent_scratchpad}"""
POWERBI_CHAT_PREFIX = """Assistant is a large language model built to help users interact with a PowerBI Dataset.
Assistant should try to create a correct and complete answer to the question from the user. If the user asks a question not related to the dataset it should return "This does not appear to be part of this dataset." as the answer. The user might make a mistake with the spelling of certain values, if you think that is the case, ask the user to confirm the spelling of the value and then run the query again. Unless the user specifies a specific number of examples they wish to obtain, and the results are too large, limit your query to at most {top_k} results, but make it clear when answering which field was used for the filtering. The user has access to these tables: {{tables}}.
The answer should be a complete sentence that answers the question, if multiple rows are asked find a way to write that in an easily readable format for a human, also make sure to represent numbers in readable ways, like 1M instead of 1000000.
"""
POWERBI_CHAT_SUFFIX = """TOOLS
------
Assistant can ask the user to use tools to look up information that may be helpful in answering the users original question. The tools the human can use are:
{{tools}}
{format_instructions}
USER'S INPUT
--------------------
Here is the user's input (remember to respond with a markdown code snippet of a json blob with a single action, and NOTHING else):
{{{{input}}}}
"""

View File

@@ -1,117 +0,0 @@
"""Toolkit for interacting with a Power BI dataset."""
from __future__ import annotations
from typing import TYPE_CHECKING, List, Optional, Union
from langchain_core.callbacks import BaseCallbackManager
from langchain_core.language_models import BaseLanguageModel
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.prompts import PromptTemplate
from langchain_core.prompts.chat import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
from langchain_core.tools import BaseTool
from langchain_core.tools.base import BaseToolkit
from pydantic import ConfigDict, Field
from langchain_community.tools.powerbi.prompt import (
QUESTION_TO_QUERY_BASE,
SINGLE_QUESTION_TO_QUERY,
USER_INPUT,
)
from langchain_community.tools.powerbi.tool import (
InfoPowerBITool,
ListPowerBITool,
QueryPowerBITool,
)
from langchain_community.utilities.powerbi import PowerBIDataset
if TYPE_CHECKING:
from langchain.chains.llm import LLMChain
class PowerBIToolkit(BaseToolkit):
"""Toolkit for interacting with Power BI dataset.
*Security Note*: This toolkit interacts with an external service.
Control access to who can use this toolkit.
Make sure that the capabilities given by this toolkit to the calling
code are appropriately scoped to the application.
See https://python.langchain.com/docs/security for more information.
Parameters:
powerbi: The Power BI dataset.
llm: The language model to use.
examples: Optional. The examples for the prompt. Default is None.
max_iterations: Optional. The maximum iterations to run. Default is 5.
callback_manager: Optional. The callback manager. Default is None.
output_token_limit: The output token limit. Default is 4000.
tiktoken_model_name: Optional. The TikToken model name. Default is None.
"""
powerbi: PowerBIDataset = Field(exclude=True)
llm: Union[BaseLanguageModel, BaseChatModel] = Field(exclude=True)
examples: Optional[str] = None
max_iterations: int = 5
callback_manager: Optional[BaseCallbackManager] = None
output_token_limit: int = 4000
tiktoken_model_name: Optional[str] = None
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
return [
QueryPowerBITool(
llm_chain=self._get_chain(),
powerbi=self.powerbi,
examples=self.examples,
max_iterations=self.max_iterations,
output_token_limit=self.output_token_limit,
tiktoken_model_name=self.tiktoken_model_name,
),
InfoPowerBITool(powerbi=self.powerbi),
ListPowerBITool(powerbi=self.powerbi),
]
def _get_chain(self) -> LLMChain:
"""Construct the chain based on the callback manager and model type."""
from langchain.chains.llm import LLMChain
if isinstance(self.llm, BaseLanguageModel):
return LLMChain(
llm=self.llm,
callback_manager=self.callback_manager
if self.callback_manager
else None,
prompt=PromptTemplate(
template=SINGLE_QUESTION_TO_QUERY,
input_variables=["tool_input", "tables", "schemas", "examples"],
),
)
system_prompt = SystemMessagePromptTemplate(
prompt=PromptTemplate(
template=QUESTION_TO_QUERY_BASE,
input_variables=["tables", "schemas", "examples"],
)
)
human_prompt = HumanMessagePromptTemplate(
prompt=PromptTemplate(
template=USER_INPUT,
input_variables=["tool_input"],
)
)
return LLMChain(
llm=self.llm,
callback_manager=self.callback_manager if self.callback_manager else None,
prompt=ChatPromptTemplate.from_messages([system_prompt, human_prompt]),
)

View File

@@ -1 +0,0 @@
"""Slack toolkit."""

View File

@@ -1,112 +0,0 @@
from __future__ import annotations
from typing import TYPE_CHECKING, List
from langchain_core.tools import BaseTool
from langchain_core.tools.base import BaseToolkit
from pydantic import ConfigDict, Field
from langchain_community.tools.slack.get_channel import SlackGetChannel
from langchain_community.tools.slack.get_message import SlackGetMessage
from langchain_community.tools.slack.schedule_message import SlackScheduleMessage
from langchain_community.tools.slack.send_message import SlackSendMessage
from langchain_community.tools.slack.utils import login
if TYPE_CHECKING:
# This is for linting and IDE typehints
from slack_sdk import WebClient
else:
try:
# We do this so pydantic can resolve the types when instantiating
from slack_sdk import WebClient
except ImportError:
pass
class SlackToolkit(BaseToolkit):
"""Toolkit for interacting with Slack.
Parameters:
client: The Slack client.
Setup:
Install ``slack_sdk`` and set environment variable ``SLACK_USER_TOKEN``.
.. code-block:: bash
pip install -U slack_sdk
export SLACK_USER_TOKEN="your-user-token"
Key init args:
client: slack_sdk.WebClient
The Slack client.
Instantiate:
.. code-block:: python
from langchain_community.agent_toolkits import SlackToolkit
toolkit = SlackToolkit()
Tools:
.. code-block:: python
tools = toolkit.get_tools()
tools
.. code-block:: none
[SlackGetChannel(client=<slack_sdk.web.client.WebClient object at 0x113caa8c0>),
SlackGetMessage(client=<slack_sdk.web.client.WebClient object at 0x113caa4d0>),
SlackScheduleMessage(client=<slack_sdk.web.client.WebClient object at 0x113caa440>),
SlackSendMessage(client=<slack_sdk.web.client.WebClient object at 0x113caa410>)]
Use within an agent:
.. code-block:: python
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import create_react_agent
llm = ChatOpenAI(model="gpt-4o-mini")
agent_executor = create_react_agent(llm, tools)
example_query = "When was the #general channel created?"
events = agent_executor.stream(
{"messages": [("user", example_query)]},
stream_mode="values",
)
for event in events:
message = event["messages"][-1]
if message.type != "tool": # mask sensitive information
event["messages"][-1].pretty_print()
.. code-block:: none
================================ Human Message =================================
When was the #general channel created?
================================== Ai Message ==================================
Tool Calls:
get_channelid_name_dict (call_NXDkALjoOx97uF1v0CoZTqtJ)
Call ID: call_NXDkALjoOx97uF1v0CoZTqtJ
Args:
================================== Ai Message ==================================
The #general channel was created on timestamp 1671043305.
""" # noqa: E501
client: WebClient = Field(default_factory=login)
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
return [
SlackGetChannel(),
SlackGetMessage(),
SlackScheduleMessage(),
SlackSendMessage(),
]


@@ -1 +0,0 @@
"""Spark SQL agent."""


@@ -1,93 +0,0 @@
"""Spark SQL agent."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from langchain_core.callbacks import BaseCallbackManager, Callbacks
from langchain_core.language_models import BaseLanguageModel
from langchain_community.agent_toolkits.spark_sql.prompt import SQL_PREFIX, SQL_SUFFIX
from langchain_community.agent_toolkits.spark_sql.toolkit import SparkSQLToolkit
if TYPE_CHECKING:
from langchain.agents.agent import AgentExecutor
def create_spark_sql_agent(
llm: BaseLanguageModel,
toolkit: SparkSQLToolkit,
callback_manager: Optional[BaseCallbackManager] = None,
callbacks: Callbacks = None,
prefix: str = SQL_PREFIX,
suffix: str = SQL_SUFFIX,
format_instructions: Optional[str] = None,
input_variables: Optional[List[str]] = None,
top_k: int = 10,
max_iterations: Optional[int] = 15,
max_execution_time: Optional[float] = None,
early_stopping_method: str = "force",
verbose: bool = False,
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> AgentExecutor:
"""Construct a Spark SQL agent from an LLM and tools.
Args:
llm: The language model to use.
toolkit: The Spark SQL toolkit.
callback_manager: Optional. The callback manager. Default is None.
callbacks: Optional. The callbacks. Default is None.
prefix: Optional. The prefix for the prompt. Default is SQL_PREFIX.
suffix: Optional. The suffix for the prompt. Default is SQL_SUFFIX.
format_instructions: Optional. The format instructions for the prompt.
Default is None.
input_variables: Optional. The input variables for the prompt. Default is None.
top_k: Optional. The top k for the prompt. Default is 10.
max_iterations: Optional. The maximum iterations to run. Default is 15.
max_execution_time: Optional. The maximum execution time. Default is None.
early_stopping_method: Optional. The early stopping method. Default is "force".
verbose: Optional. Whether to print verbose output. Default is False.
agent_executor_kwargs: Optional. The agent executor kwargs. Default is None.
kwargs: Any. Additional keyword arguments.
Returns:
The agent executor.
"""
from langchain.agents.agent import AgentExecutor
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.chains.llm import LLMChain
tools = toolkit.get_tools()
prefix = prefix.format(top_k=top_k)
prompt_params = (
{"format_instructions": format_instructions}
if format_instructions is not None
else {}
)
prompt = ZeroShotAgent.create_prompt(
tools,
prefix=prefix,
suffix=suffix,
input_variables=input_variables,
**prompt_params,
)
llm_chain = LLMChain(
llm=llm,
prompt=prompt,
callback_manager=callback_manager,
callbacks=callbacks,
)
tool_names = [tool.name for tool in tools]
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
return AgentExecutor.from_agent_and_tools(
agent=agent,
tools=tools,
callback_manager=callback_manager,
callbacks=callbacks,
verbose=verbose,
max_iterations=max_iterations,
max_execution_time=max_execution_time,
early_stopping_method=early_stopping_method,
**(agent_executor_kwargs or {}),
)
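# A minimal usage sketch (hedged): assumes an active Spark session and an
# existing "langchain_example" schema; adjust names to your environment.
from langchain_community.utilities.spark_sql import SparkSQL
from langchain_openai import ChatOpenAI

spark_sql = SparkSQL(schema="langchain_example")
llm = ChatOpenAI(temperature=0)
toolkit = SparkSQLToolkit(db=spark_sql, llm=llm)
agent_executor = create_spark_sql_agent(llm=llm, toolkit=toolkit, verbose=True)
agent_executor.invoke({"input": "How many tables are there?"})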


@@ -1,21 +0,0 @@
# flake8: noqa
SQL_PREFIX = """You are an agent designed to interact with Spark SQL.
Given an input question, create a syntactically correct Spark SQL query to run, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
You have access to tools for interacting with the database.
Only use the below tools. Only use the information returned by the below tools to construct your final answer.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.
DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
If the question does not seem related to the database, just return "I don't know" as the answer.
"""
SQL_SUFFIX = """Begin!
Question: {input}
Thought: I should look at the tables in the database to see what I can query.
{agent_scratchpad}"""


@@ -1,41 +0,0 @@
"""Toolkit for interacting with Spark SQL."""
from typing import List
from langchain_core.language_models import BaseLanguageModel
from langchain_core.tools import BaseTool
from langchain_core.tools.base import BaseToolkit
from pydantic import ConfigDict, Field
from langchain_community.tools.spark_sql.tool import (
InfoSparkSQLTool,
ListSparkSQLTool,
QueryCheckerTool,
QuerySparkSQLTool,
)
from langchain_community.utilities.spark_sql import SparkSQL
class SparkSQLToolkit(BaseToolkit):
"""Toolkit for interacting with Spark SQL.
Parameters:
db: SparkSQL. The Spark SQL database.
llm: BaseLanguageModel. The language model.
"""
db: SparkSQL = Field(exclude=True)
llm: BaseLanguageModel = Field(exclude=True)
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
return [
QuerySparkSQLTool(db=self.db),
InfoSparkSQLTool(db=self.db),
ListSparkSQLTool(db=self.db),
QueryCheckerTool(db=self.db, llm=self.llm),
]


@@ -1 +0,0 @@
"""SQL agent."""


@@ -1,241 +0,0 @@
"""SQL agent."""
from __future__ import annotations
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Literal,
Optional,
Sequence,
Union,
cast,
)
from langchain_core.messages import AIMessage, SystemMessage
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
from langchain_core.prompts.chat import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
MessagesPlaceholder,
)
from langchain_community.agent_toolkits.sql.prompt import (
SQL_FUNCTIONS_SUFFIX,
SQL_PREFIX,
SQL_SUFFIX,
)
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain_community.tools.sql_database.tool import (
InfoSQLDatabaseTool,
ListSQLDatabaseTool,
)
if TYPE_CHECKING:
from langchain.agents.agent import AgentExecutor
from langchain.agents.agent_types import AgentType
from langchain_core.callbacks import BaseCallbackManager
from langchain_core.language_models import BaseLanguageModel
from langchain_core.tools import BaseTool
from langchain_community.utilities.sql_database import SQLDatabase
def create_sql_agent(
llm: BaseLanguageModel,
toolkit: Optional[SQLDatabaseToolkit] = None,
agent_type: Optional[
Union[AgentType, Literal["openai-tools", "tool-calling"]]
] = None,
callback_manager: Optional[BaseCallbackManager] = None,
prefix: Optional[str] = None,
suffix: Optional[str] = None,
format_instructions: Optional[str] = None,
input_variables: Optional[List[str]] = None,
top_k: int = 10,
max_iterations: Optional[int] = 15,
max_execution_time: Optional[float] = None,
early_stopping_method: str = "force",
verbose: bool = False,
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
extra_tools: Sequence[BaseTool] = (),
*,
db: Optional[SQLDatabase] = None,
prompt: Optional[BasePromptTemplate] = None,
**kwargs: Any,
) -> AgentExecutor:
"""Construct a SQL agent from an LLM and toolkit or database.
Args:
llm: Language model to use for the agent. If agent_type is "tool-calling" then
llm is expected to support tool calling.
toolkit: SQLDatabaseToolkit for the agent to use. Must provide exactly one of
'toolkit' or 'db'. Specify 'toolkit' if you want to use a different model
for the agent and the toolkit.
agent_type: One of "tool-calling", "openai-tools", "openai-functions", or
"zero-shot-react-description". Defaults to "zero-shot-react-description".
"tool-calling" is recommended over the legacy "openai-tools" and
"openai-functions" types.
callback_manager: DEPRECATED. Pass "callbacks" key into 'agent_executor_kwargs'
instead to pass constructor callbacks to AgentExecutor.
prefix: Prompt prefix string. Must contain variables "top_k" and "dialect".
suffix: Prompt suffix string. Default depends on agent type.
format_instructions: Formatting instructions to pass to
ZeroShotAgent.create_prompt() when 'agent_type' is
"zero-shot-react-description". Otherwise ignored.
input_variables: DEPRECATED.
top_k: Number of rows to query for by default.
max_iterations: Passed to AgentExecutor init.
max_execution_time: Passed to AgentExecutor init.
early_stopping_method: Passed to AgentExecutor init.
verbose: AgentExecutor verbosity.
agent_executor_kwargs: Arbitrary additional AgentExecutor args.
extra_tools: Additional tools to give the agent on top of the ones that come with
SQLDatabaseToolkit.
db: SQLDatabase from which to create a SQLDatabaseToolkit. Toolkit is created
using 'db' and 'llm'. Must provide exactly one of 'db' or 'toolkit'.
prompt: Complete agent prompt. prompt and {prefix, suffix, format_instructions,
input_variables} are mutually exclusive.
**kwargs: Arbitrary additional Agent args.
Returns:
An AgentExecutor with the specified agent_type agent.
Example:
.. code-block:: python
from langchain_openai import ChatOpenAI
from langchain_community.agent_toolkits import create_sql_agent
from langchain_community.utilities import SQLDatabase
db = SQLDatabase.from_uri("sqlite:///Chinook.db")
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
agent_executor = create_sql_agent(llm, db=db, agent_type="tool-calling", verbose=True)
""" # noqa: E501
from langchain.agents import (
create_openai_functions_agent,
create_openai_tools_agent,
create_react_agent,
create_tool_calling_agent,
)
from langchain.agents.agent import (
AgentExecutor,
RunnableAgent,
RunnableMultiActionAgent,
)
from langchain.agents.agent_types import AgentType
if toolkit is None and db is None:
raise ValueError(
"Must provide exactly one of 'toolkit' or 'db'. Received neither."
)
if toolkit and db:
raise ValueError(
"Must provide exactly one of 'toolkit' or 'db'. Received both."
)
toolkit = toolkit or SQLDatabaseToolkit(llm=llm, db=db) # type: ignore[arg-type]
agent_type = agent_type or AgentType.ZERO_SHOT_REACT_DESCRIPTION
tools = toolkit.get_tools() + list(extra_tools)
if prefix is None:
prefix = SQL_PREFIX
if prompt is None:
prefix = prefix.format(dialect=toolkit.dialect, top_k=top_k)
else:
if "top_k" in prompt.input_variables:
prompt = prompt.partial(top_k=str(top_k))
if "dialect" in prompt.input_variables:
prompt = prompt.partial(dialect=toolkit.dialect)
if any(key in prompt.input_variables for key in ["table_info", "table_names"]):
db_context = toolkit.get_context()
if "table_info" in prompt.input_variables:
prompt = prompt.partial(table_info=db_context["table_info"])
tools = [
tool for tool in tools if not isinstance(tool, InfoSQLDatabaseTool)
]
if "table_names" in prompt.input_variables:
prompt = prompt.partial(table_names=db_context["table_names"])
tools = [
tool for tool in tools if not isinstance(tool, ListSQLDatabaseTool)
]
if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION:
if prompt is None:
from langchain.agents.mrkl import prompt as react_prompt
format_instructions = (
format_instructions or react_prompt.FORMAT_INSTRUCTIONS
)
template = "\n\n".join(
[
prefix,
"{tools}",
format_instructions,
suffix or SQL_SUFFIX,
]
)
prompt = PromptTemplate.from_template(template)
agent = RunnableAgent(
runnable=create_react_agent(llm, tools, prompt),
input_keys_arg=["input"],
return_keys_arg=["output"],
**kwargs,
)
elif agent_type == AgentType.OPENAI_FUNCTIONS:
if prompt is None:
messages: List = [
SystemMessage(content=cast(str, prefix)),
HumanMessagePromptTemplate.from_template("{input}"),
AIMessage(content=suffix or SQL_FUNCTIONS_SUFFIX),
MessagesPlaceholder(variable_name="agent_scratchpad"),
]
prompt = ChatPromptTemplate.from_messages(messages)
agent = RunnableAgent(
runnable=create_openai_functions_agent(llm, tools, prompt), # type: ignore[arg-type]
input_keys_arg=["input"],
return_keys_arg=["output"],
**kwargs,
)
elif agent_type in ("openai-tools", "tool-calling"):
if prompt is None:
messages = [
SystemMessage(content=cast(str, prefix)),
HumanMessagePromptTemplate.from_template("{input}"),
AIMessage(content=suffix or SQL_FUNCTIONS_SUFFIX),
MessagesPlaceholder(variable_name="agent_scratchpad"),
]
prompt = ChatPromptTemplate.from_messages(messages)
if agent_type == "openai-tools":
runnable = create_openai_tools_agent(llm, tools, prompt) # type: ignore[arg-type]
else:
runnable = create_tool_calling_agent(llm, tools, prompt) # type: ignore[arg-type]
agent = RunnableMultiActionAgent( # type: ignore[assignment]
runnable=runnable,
input_keys_arg=["input"],
return_keys_arg=["output"],
**kwargs,
)
else:
raise ValueError(
f"Agent type {agent_type} not supported at the moment. Must be one of "
"'tool-calling', 'openai-tools', 'openai-functions', or "
"'zero-shot-react-description'."
)
return AgentExecutor(
name="SQL Agent Executor",
agent=agent,
tools=tools,
callback_manager=callback_manager,
verbose=verbose,
max_iterations=max_iterations,
max_execution_time=max_execution_time,
early_stopping_method=early_stopping_method,
**(agent_executor_kwargs or {}),
)
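# A sketch (hedged) of the custom-prompt path above, reusing the llm and db from
# the docstring example: "dialect" and "top_k" are filled via prompt.partial(),
# so the caller only supplies "input".
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You are a {dialect} expert. Return at most {top_k} rows."),
        ("human", "{input}"),
        MessagesPlaceholder(variable_name="agent_scratchpad"),
    ]
)
agent_executor = create_sql_agent(
    llm, db=db, prompt=prompt, agent_type="tool-calling", verbose=True
)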


@@ -1,23 +0,0 @@
# flake8: noqa
SQL_PREFIX = """You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
You have access to tools for interacting with the database.
Only use the below tools. Only use the information returned by the below tools to construct your final answer.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.
DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
If the question does not seem related to the database, just return "I don't know" as the answer.
"""
SQL_SUFFIX = """Begin!
Question: {input}
Thought: I should look at the tables in the database to see what I can query. Then I should query the schema of the most relevant tables.
{agent_scratchpad}"""
SQL_FUNCTIONS_SUFFIX = """I should look at the tables in the database to see what I can query. Then I should query the schema of the most relevant tables."""


@@ -1,139 +0,0 @@
"""Toolkit for interacting with an SQL database."""
from typing import List
from langchain_core.caches import BaseCache as BaseCache
from langchain_core.callbacks import Callbacks as Callbacks
from langchain_core.language_models import BaseLanguageModel
from langchain_core.tools import BaseTool
from langchain_core.tools.base import BaseToolkit
from pydantic import ConfigDict, Field
from langchain_community.tools.sql_database.tool import (
InfoSQLDatabaseTool,
ListSQLDatabaseTool,
QuerySQLCheckerTool,
QuerySQLDatabaseTool,
)
from langchain_community.tools.sql_database.tool import (
QuerySQLDataBaseTool as QuerySQLDataBaseTool, # keep import for backwards compat.
)
from langchain_community.utilities.sql_database import SQLDatabase
class SQLDatabaseToolkit(BaseToolkit):
"""SQLDatabaseToolkit for interacting with SQL databases.
Setup:
Install ``langchain-community``.
.. code-block:: bash
pip install -U langchain-community
Key init args:
db: SQLDatabase
The SQL database.
llm: BaseLanguageModel
The language model (for use with QuerySQLCheckerTool)
Instantiate:
.. code-block:: python
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain_community.utilities.sql_database import SQLDatabase
from langchain_openai import ChatOpenAI
db = SQLDatabase.from_uri("sqlite:///Chinook.db")
llm = ChatOpenAI(temperature=0)
toolkit = SQLDatabaseToolkit(db=db, llm=llm)
Tools:
.. code-block:: python
toolkit.get_tools()
Use within an agent:
.. code-block:: python
from langchain import hub
from langgraph.prebuilt import create_react_agent
# Pull prompt (or define your own)
prompt_template = hub.pull("langchain-ai/sql-agent-system-prompt")
system_message = prompt_template.format(dialect="SQLite", top_k=5)
# Create agent
agent_executor = create_react_agent(
llm, toolkit.get_tools(), state_modifier=system_message
)
# Query agent
example_query = "Which country's customers spent the most?"
events = agent_executor.stream(
{"messages": [("user", example_query)]},
stream_mode="values",
)
for event in events:
event["messages"][-1].pretty_print()
""" # noqa: E501
db: SQLDatabase = Field(exclude=True)
llm: BaseLanguageModel = Field(exclude=True)
@property
def dialect(self) -> str:
"""Return string representation of SQL dialect to use."""
return self.db.dialect
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
list_sql_database_tool = ListSQLDatabaseTool(db=self.db)
info_sql_database_tool_description = (
"Input to this tool is a comma-separated list of tables, output is the "
"schema and sample rows for those tables. "
"Be sure that the tables actually exist by calling "
f"{list_sql_database_tool.name} first! "
"Example Input: table1, table2, table3"
)
info_sql_database_tool = InfoSQLDatabaseTool(
db=self.db, description=info_sql_database_tool_description
)
query_sql_database_tool_description = (
"Input to this tool is a detailed and correct SQL query, output is a "
"result from the database. If the query is not correct, an error message "
"will be returned. If an error is returned, rewrite the query, check the "
"query, and try again. If you encounter an issue with Unknown column "
f"'xxxx' in 'field list', use {info_sql_database_tool.name} "
"to query the correct table fields."
)
query_sql_database_tool = QuerySQLDatabaseTool(
db=self.db, description=query_sql_database_tool_description
)
query_sql_checker_tool_description = (
"Use this tool to double check if your query is correct before executing "
"it. Always use this tool before executing a query with "
f"{query_sql_database_tool.name}!"
)
query_sql_checker_tool = QuerySQLCheckerTool(
db=self.db, llm=self.llm, description=query_sql_checker_tool_description
)
return [
query_sql_database_tool,
info_sql_database_tool,
list_sql_database_tool,
query_sql_checker_tool,
]
def get_context(self) -> dict:
"""Return db context that you may want in agent prompt."""
return self.db.get_context()
SQLDatabaseToolkit.model_rebuild()
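# A small sketch (hedged), continuing the docstring example: get_context()
# exposes the values create_sql_agent uses to pre-fill the "table_info" and
# "table_names" prompt variables (and then drops the corresponding tools).
context = toolkit.get_context()
print(context["table_names"])       # e.g. "Album, Artist, Customer, ..."
print(context["table_info"][:200])  # CREATE TABLE statements plus sample rows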


@@ -1 +0,0 @@
"""Steam Toolkit."""


@@ -1,62 +0,0 @@
"""Steam Toolkit."""
from typing import List
from langchain_core.tools import BaseTool
from langchain_core.tools.base import BaseToolkit
from langchain_community.tools.steam.prompt import (
STEAM_GET_GAMES_DETAILS,
STEAM_GET_RECOMMENDED_GAMES,
)
from langchain_community.tools.steam.tool import SteamWebAPIQueryRun
from langchain_community.utilities.steam import SteamWebAPIWrapper
class SteamToolkit(BaseToolkit):
"""Steam Toolkit.
Parameters:
tools: List[BaseTool]. The tools in the toolkit. Default is an empty list.
"""
tools: List[BaseTool] = []
@classmethod
def from_steam_api_wrapper(
cls, steam_api_wrapper: SteamWebAPIWrapper
) -> "SteamToolkit":
"""Create a Steam Toolkit from a Steam API Wrapper.
Args:
steam_api_wrapper: SteamWebAPIWrapper. The Steam API Wrapper.
Returns:
SteamToolkit. The Steam Toolkit.
"""
operations: List[dict] = [
{
"mode": "get_games_details",
"name": "Get Games Details",
"description": STEAM_GET_GAMES_DETAILS,
},
{
"mode": "get_recommended_games",
"name": "Get Recommended Games",
"description": STEAM_GET_RECOMMENDED_GAMES,
},
]
tools = [
SteamWebAPIQueryRun(
name=action["name"],
description=action["description"],
mode=action["mode"],
api_wrapper=steam_api_wrapper,
)
for action in operations
]
return cls(tools=tools) # type: ignore[arg-type]
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
return self.tools
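# A minimal usage sketch (hedged): SteamWebAPIWrapper typically requires the
# STEAM_KEY environment variable and the python-steam-api package.
toolkit = SteamToolkit.from_steam_api_wrapper(SteamWebAPIWrapper())
for tool in toolkit.get_tools():
    print(tool.name)  # "Get Games Details", "Get Recommended Games"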


@@ -1,26 +0,0 @@
from pathlib import Path
from typing import Any
from langchain_core._api.path import as_import_path
def __getattr__(name: str) -> Any:
"""Get attr name."""
if name == "create_xorbits_agent":
# Get directory of langchain package
HERE = Path(__file__).parents[3]
here = as_import_path(Path(__file__).parent, relative_to=HERE)
old_path = "langchain." + here + "." + name
new_path = "langchain_experimental." + here + "." + name
raise ImportError(
"This agent has been moved to langchain experiment. "
"This agent relies on python REPL tool under the hood, so to use it "
"safely please sandbox the python REPL. "
"Read https://github.com/langchain-ai/langchain/blob/master/SECURITY.md "
"and https://github.com/langchain-ai/langchain/discussions/11680"
"To keep using this code as is, install langchain experimental and "
f"update your import statement from:\n `{old_path}` to `{new_path}`."
)
raise AttributeError(f"{name} does not exist")


@@ -1 +0,0 @@
"""Zapier Toolkit."""


@@ -1,79 +0,0 @@
"""[DEPRECATED] Zapier Toolkit."""
from typing import List
from langchain_core._api import warn_deprecated
from langchain_core.tools import BaseTool
from langchain_core.tools.base import BaseToolkit
from langchain_community.tools.zapier.tool import ZapierNLARunAction
from langchain_community.utilities.zapier import ZapierNLAWrapper
class ZapierToolkit(BaseToolkit):
"""Zapier Toolkit.
Parameters:
tools: List[BaseTool]. The tools in the toolkit. Default is an empty list.
"""
tools: List[BaseTool] = []
@classmethod
def from_zapier_nla_wrapper(
cls, zapier_nla_wrapper: ZapierNLAWrapper
) -> "ZapierToolkit":
"""Create a toolkit from a ZapierNLAWrapper.
Args:
zapier_nla_wrapper: ZapierNLAWrapper. The Zapier NLA wrapper.
Returns:
ZapierToolkit. The Zapier toolkit.
"""
actions = zapier_nla_wrapper.list()
tools = [
ZapierNLARunAction(
action_id=action["id"],
zapier_description=action["description"],
params_schema=action["params"],
api_wrapper=zapier_nla_wrapper,
)
for action in actions
]
return cls(tools=tools) # type: ignore[arg-type]
@classmethod
async def async_from_zapier_nla_wrapper(
cls, zapier_nla_wrapper: ZapierNLAWrapper
) -> "ZapierToolkit":
"""Async create a toolkit from a ZapierNLAWrapper.
Args:
zapier_nla_wrapper: ZapierNLAWrapper. The Zapier NLA wrapper.
Returns:
ZapierToolkit. The Zapier toolkit.
"""
actions = await zapier_nla_wrapper.alist()
tools = [
ZapierNLARunAction(
action_id=action["id"],
zapier_description=action["description"],
params_schema=action["params"],
api_wrapper=zapier_nla_wrapper,
)
for action in actions
]
return cls(tools=tools) # type: ignore[arg-type]
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
warn_deprecated(
since="0.0.319",
message=(
"This tool will be deprecated on 2023-11-17. See "
"<https://nla.zapier.com/sunset/> for details"
),
)
return self.tools
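# A usage sketch (hedged, historical -- the Zapier NLA API has been sunset):
# ZapierNLAWrapper reads ZAPIER_NLA_API_KEY from the environment.
toolkit = ZapierToolkit.from_zapier_nla_wrapper(ZapierNLAWrapper())
tools = toolkit.get_tools()  # one ZapierNLARunAction per exposed Zapier action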


@@ -1,3 +0,0 @@
from langchain_community.agents.openai_assistant.base import OpenAIAssistantV2Runnable
__all__ = ["OpenAIAssistantV2Runnable"]


@@ -1,628 +0,0 @@
from __future__ import annotations
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Optional,
Sequence,
Type,
Union,
)
from langchain.agents.openai_assistant.base import OpenAIAssistantRunnable, OutputType
from langchain_core._api import beta
from langchain_core.callbacks import CallbackManager
from langchain_core.load import dumpd
from langchain_core.runnables import RunnableConfig, ensure_config
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_tool
from pydantic import BaseModel, Field, model_validator
from typing_extensions import Self
if TYPE_CHECKING:
import openai
from openai._types import NotGiven
from openai.types.beta.assistant import ToolResources as AssistantToolResources
def _get_openai_client() -> openai.OpenAI:
"""Get the OpenAI client.
Returns:
openai.OpenAI: OpenAI client
Raises:
ImportError: If `openai` is not installed.
AttributeError: If the installed `openai` version is not compatible.
"""
try:
import openai
return openai.OpenAI(default_headers={"OpenAI-Beta": "assistants=v2"})
except ImportError as e:
raise ImportError(
"Unable to import openai, please install with `pip install openai`."
) from e
except AttributeError as e:
raise AttributeError(
"Please make sure you are using a v1.23-compatible version of openai. You "
'can install with `pip install "openai>=1.23"`.'
) from e
def _get_openai_async_client() -> openai.AsyncOpenAI:
"""Get the async OpenAI client.
Returns:
openai.AsyncOpenAI: Async OpenAI client
Raises:
ImportError: If `openai` is not installed.
AttributeError: If the installed `openai` version is not compatible.
"""
try:
import openai
return openai.AsyncOpenAI(default_headers={"OpenAI-Beta": "assistants=v2"})
except ImportError as e:
raise ImportError(
"Unable to import openai, please install with `pip install openai`."
) from e
except AttributeError as e:
raise AttributeError(
"Please make sure you are using a v1.23-compatible version of openai. You "
'can install with `pip install "openai>=1.23"`.'
) from e
def _convert_file_ids_into_attachments(file_ids: list) -> list:
"""Convert file_ids into attachments
File search and Code interpreter will be turned on by default.
Args:
file_ids (list): List of file_ids that need to be converted into attachments.
Returns:
list: List of attachments converted from file_ids.
"""
attachments = []
for file_id in file_ids:
    attachments.append(
        {
            "file_id": file_id,
            "tools": [{"type": "file_search"}, {"type": "code_interpreter"}],
        }
    )
return attachments
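# For example, two file ids become two v2 attachments, each with both built-in
# tools enabled:
#   _convert_file_ids_into_attachments(["file-a", "file-b"])
#   -> [{"file_id": "file-a",
#        "tools": [{"type": "file_search"}, {"type": "code_interpreter"}]},
#       {"file_id": "file-b",
#        "tools": [{"type": "file_search"}, {"type": "code_interpreter"}]}]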
def _is_assistants_builtin_tool(
tool: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool],
) -> bool:
"""Determine if tool corresponds to OpenAI Assistants built-in.
Args:
tool (Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]):
Tool that needs to be determined.
Returns:
True if the tool is one of the OpenAI Assistants built-in tools
("code_interpreter", "retrieval", or "file_search"), False otherwise.
"""
assistants_builtin_tools = ("code_interpreter", "retrieval", "file_search")
return (
isinstance(tool, dict)
and ("type" in tool)
and (tool["type"] in assistants_builtin_tools)
)
def _get_assistants_tool(
tool: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool],
) -> Dict[str, Any]:
"""Convert a raw function/class to an OpenAI tool.
Note that the OpenAI Assistants API supports several built-in tools,
such as "code_interpreter" and "retrieval".
Args:
tool (Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]):
Tools or functions that need to be converted to OpenAI tools.
Returns:
Dict[str, Any]: A dictionary of tools that are converted into OpenAI tools.
"""
if _is_assistants_builtin_tool(tool):
return tool # type: ignore[return-value]
else:
return convert_to_openai_tool(tool)
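# For example: built-in tool dicts pass through unchanged, while functions and
# BaseTools go through convert_to_openai_tool (sketch of the expected shapes):
#   _get_assistants_tool({"type": "code_interpreter"})
#   -> {"type": "code_interpreter"}
#   _get_assistants_tool(my_python_function)
#   -> {"type": "function", "function": {"name": ..., "description": ...,
#       "parameters": ...}}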
@beta()
class OpenAIAssistantV2Runnable(OpenAIAssistantRunnable):
"""Run an OpenAI Assistant.
Attributes:
client (Any): OpenAI or AzureOpenAI client.
async_client (Any): Async OpenAI or AzureOpenAI client.
assistant_id (str): OpenAI assistant ID.
check_every_ms (float): Frequency to check progress in milliseconds.
as_agent (bool): Whether to use the assistant as a LangChain agent.
Example using OpenAI tools:
.. code-block:: python
from langchain.agents.openai_assistant import OpenAIAssistantV2Runnable
assistant = OpenAIAssistantV2Runnable.create_assistant(
name="math assistant",
instructions="You are a personal math tutor. Write and run code to answer math questions.",
tools=[{"type": "code_interpreter"}],
model="gpt-4-1106-preview"
)
output = assistant.invoke({"content": "What's 10 - 4 raised to the 2.7"})
Example using custom tools and AgentExecutor:
.. code-block:: python
from langchain.agents.openai_assistant import OpenAIAssistantV2Runnable
from langchain.agents import AgentExecutor
from langchain.tools import E2BDataAnalysisTool
tools = [E2BDataAnalysisTool(api_key="...")]
agent = OpenAIAssistantV2Runnable.create_assistant(
name="langchain assistant e2b tool",
instructions="You are a personal math tutor. Write and run code to answer math questions.",
tools=tools,
model="gpt-4-1106-preview",
as_agent=True
)
agent_executor = AgentExecutor(agent=agent, tools=tools)
agent_executor.invoke({"content": "Analyze the data..."})
Example using custom tools and custom execution:
.. code-block:: python
from langchain.agents.openai_assistant import OpenAIAssistantV2Runnable
from langchain.agents import AgentExecutor
from langchain_core.agents import AgentFinish
from langchain.tools import E2BDataAnalysisTool
tools = [E2BDataAnalysisTool(api_key="...")]
agent = OpenAIAssistantV2Runnable.create_assistant(
name="langchain assistant e2b tool",
instructions="You are a personal math tutor. Write and run code to answer math questions.",
tools=tools,
model="gpt-4-1106-preview",
as_agent=True
)
def execute_agent(agent, tools, input):
tool_map = {tool.name: tool for tool in tools}
response = agent.invoke(input)
while not isinstance(response, AgentFinish):
tool_outputs = []
for action in response:
tool_output = tool_map[action.tool].invoke(action.tool_input)
tool_outputs.append({"output": tool_output, "tool_call_id": action.tool_call_id})
response = agent.invoke(
{
"tool_outputs": tool_outputs,
"run_id": action.run_id,
"thread_id": action.thread_id
}
)
return response
response = execute_agent(agent, tools, {"content": "What's 10 - 4 raised to the 2.7"})
next_response = execute_agent(agent, tools, {"content": "now add 17.241", "thread_id": response.thread_id})
""" # noqa: E501
client: Any = Field(default_factory=_get_openai_client)
"""OpenAI or AzureOpenAI client."""
async_client: Any = None
"""OpenAI or AzureOpenAI async client."""
assistant_id: str
"""OpenAI assistant id."""
check_every_ms: float = 1_000.0
"""Frequency with which to check run progress in milliseconds."""
as_agent: bool = False
"""Use as a LangChain agent, compatible with the AgentExecutor."""
@model_validator(mode="after")
def validate_async_client(self) -> Self:
"""Validate that the async client is set, otherwise initialize it."""
if self.async_client is None:
import openai
api_key = self.client.api_key
self.async_client = openai.AsyncOpenAI(api_key=api_key)
return self
@classmethod
def create_assistant(
cls,
name: str,
instructions: str,
tools: Sequence[Union[BaseTool, dict]],
model: str,
*,
model_kwargs: Optional[dict[str, float]] = None,
client: Optional[Union[openai.OpenAI, openai.AzureOpenAI]] = None,
tool_resources: Optional[Union[AssistantToolResources, dict, NotGiven]] = None,
extra_body: Optional[object] = None,
**kwargs: Any,
) -> OpenAIAssistantRunnable:
"""Create an OpenAI Assistant and instantiate the Runnable.
Args:
name (str): Assistant name.
instructions (str): Assistant instructions.
tools (Sequence[Union[BaseTool, dict]]): Assistant tools. Can be passed
in OpenAI format or as BaseTools.
tool_resources (Optional[Union[AssistantToolResources, dict, NotGiven]]):
Assistant tool resources. Can be passed in OpenAI format.
model (str): Assistant model to use.
client (Optional[Union[openai.OpenAI, openai.AzureOpenAI]]): OpenAI or
AzureOpenAI client. Will create default OpenAI client (Assistant v2)
if not specified.
model_kwargs: Additional model arguments. Only the temperature and
top_p parameters are supported.
extra_body: Additional body parameters to be passed to the assistant.
Returns:
OpenAIAssistantRunnable: The configured assistant runnable.
"""
client = client or _get_openai_client()
if tool_resources is None:
from openai._types import NOT_GIVEN
tool_resources = NOT_GIVEN
assistant = client.beta.assistants.create(
name=name,
instructions=instructions,
tools=[_get_assistants_tool(tool) for tool in tools],
tool_resources=tool_resources,
model=model,
extra_body=extra_body,
**(model_kwargs or {}),
)
return cls(assistant_id=assistant.id, client=client, **kwargs)
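# A sketch (hedged) of creating a v2 assistant with file_search resources; the
# vector store id is a placeholder created separately via the OpenAI API.
assistant = OpenAIAssistantV2Runnable.create_assistant(
    name="docs assistant",
    instructions="Answer questions using the attached files.",
    tools=[{"type": "file_search"}],
    tool_resources={"file_search": {"vector_store_ids": ["<vector-store-id>"]}},
    model="gpt-4o-mini",
    model_kwargs={"temperature": 0.0},
)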
def invoke(
self, input: dict, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> OutputType:
"""Invoke the assistant.
Args:
input (dict): Runnable input dict that can have:
content: User message when starting a new run.
thread_id: Existing thread to use.
run_id: Existing run to use. Should only be supplied when providing
the tool output for a required action after an initial invocation.
file_ids: (deprecated) File ids to include in new run. Use
'attachments' instead.
attachments: Assistant files to include in new run. (v2 API).
message_metadata: Metadata to associate with new message.
thread_metadata: Metadata to associate with new thread. Only relevant
when new thread being created.
instructions: Additional run instructions.
model: Override Assistant model for this run.
tools: Override Assistant tools for this run.
tool_resources: Override Assistant tool resources for this run (v2 API).
run_metadata: Metadata to associate with new run.
config (Optional[RunnableConfig]): Configuration for the run.
Returns:
OutputType: If self.as_agent, will return
Union[List[OpenAIAssistantAction], OpenAIAssistantFinish]. Otherwise,
will return OpenAI types
Union[List[ThreadMessage], List[RequiredActionFunctionToolCall]].
Raises:
BaseException: If an error occurs during the invocation.
"""
config = ensure_config(config)
callback_manager = CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
inheritable_tags=config.get("tags"),
inheritable_metadata=config.get("metadata"),
)
run_manager = callback_manager.on_chain_start(
dumpd(self), input, name=config.get("run_name") or self.get_name()
)
files = _convert_file_ids_into_attachments(kwargs.get("file_ids", []))
attachments = kwargs.get("attachments", []) + files
try:
# Being run within AgentExecutor and there are tool outputs to submit.
if self.as_agent and input.get("intermediate_steps"):
tool_outputs = self._parse_intermediate_steps(
input["intermediate_steps"]
)
run = self.client.beta.threads.runs.submit_tool_outputs(**tool_outputs)
# Starting a new thread and a new run.
elif "thread_id" not in input:
thread = {
"messages": [
{
"role": "user",
"content": input["content"],
"attachments": attachments,
"metadata": input.get("message_metadata"),
}
],
"metadata": input.get("thread_metadata"),
}
run = self._create_thread_and_run(input, thread)
# Starting a new run in an existing thread.
elif "run_id" not in input:
_ = self.client.beta.threads.messages.create(
input["thread_id"],
content=input["content"],
role="user",
attachments=attachments,
metadata=input.get("message_metadata"),
)
run = self._create_run(input)
# Submitting tool outputs to an existing run, outside the AgentExecutor
# framework.
else:
run = self.client.beta.threads.runs.submit_tool_outputs(**input)
run = self._wait_for_run(run.id, run.thread_id)
except BaseException as e:
run_manager.on_chain_error(e)
raise e
try:
response = self._get_response(run)
except BaseException as e:
run_manager.on_chain_error(e, metadata=run.dict())
raise e
else:
run_manager.on_chain_end(response)
return response
@classmethod
async def acreate_assistant(
cls,
name: str,
instructions: str,
tools: Sequence[Union[BaseTool, dict]],
model: str,
*,
async_client: Optional[
Union[openai.AsyncOpenAI, openai.AsyncAzureOpenAI]
] = None,
tool_resources: Optional[Union[AssistantToolResources, dict, NotGiven]] = None,
**kwargs: Any,
) -> OpenAIAssistantRunnable:
"""Create an AsyncOpenAI Assistant and instantiate the Runnable.
Args:
name (str): Assistant name.
instructions (str): Assistant instructions.
tools (Sequence[Union[BaseTool, dict]]): Assistant tools. Can be passed
in OpenAI format or as BaseTools.
tool_resources (Optional[Union[AssistantToolResources, dict, NotGiven]]):
Assistant tool resources. Can be passed in OpenAI format.
model (str): Assistant model to use.
async_client (Optional[Union[openai.AsyncOpenAI, openai.AsyncAzureOpenAI]]): OpenAI
or AzureOpenAI async client. A default async client is created if not specified.
Returns:
OpenAIAssistantRunnable: The configured assistant runnable.
"""
async_client = async_client or _get_openai_async_client()
if tool_resources is None:
from openai._types import NOT_GIVEN
tool_resources = NOT_GIVEN
openai_tools = [_get_assistants_tool(tool) for tool in tools]
assistant = await async_client.beta.assistants.create(
name=name,
instructions=instructions,
tools=openai_tools,
tool_resources=tool_resources,
model=model,
)
return cls(assistant_id=assistant.id, async_client=async_client, **kwargs)
async def ainvoke(
self, input: dict, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> OutputType:
"""Async invoke assistant.
Args:
input (dict): Runnable input dict that can have:
content: User message when starting a new run.
thread_id: Existing thread to use.
run_id: Existing run to use. Should only be supplied when providing
the tool output for a required action after an initial invocation.
file_ids: (deprecated) File ids to include in new run. Use
'attachments' instead.
attachments: Assistant files to include in new run. (v2 API).
message_metadata: Metadata to associate with new message.
thread_metadata: Metadata to associate with new thread. Only relevant
when new thread being created.
instructions: Additional run instructions.
model: Override Assistant model for this run.
tools: Override Assistant tools for this run.
tool_resources: Override Assistant tool resources for this run (v2 API).
run_metadata: Metadata to associate with new run.
config (Optional[RunnableConfig]): Configuration for the run.
Returns:
OutputType: If self.as_agent, will return
Union[List[OpenAIAssistantAction], OpenAIAssistantFinish]. Otherwise,
will return OpenAI types
Union[List[ThreadMessage], List[RequiredActionFunctionToolCall]].
Raises:
BaseException: If an error occurs during the invocation.
"""
config = ensure_config(config)
callback_manager = CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
inheritable_tags=config.get("tags"),
inheritable_metadata=config.get("metadata"),
)
run_manager = callback_manager.on_chain_start(
dumpd(self), input, name=config.get("run_name") or self.get_name()
)
files = _convert_file_ids_into_attachments(kwargs.get("file_ids", []))
attachments = kwargs.get("attachments", []) + files
try:
# Being run within AgentExecutor and there are tool outputs to submit.
if self.as_agent and input.get("intermediate_steps"):
tool_outputs = self._parse_intermediate_steps(
input["intermediate_steps"]
)
run = await self.async_client.beta.threads.runs.submit_tool_outputs(
**tool_outputs
)
# Starting a new thread and a new run.
elif "thread_id" not in input:
thread = {
"messages": [
{
"role": "user",
"content": input["content"],
"attachments": attachments,
"metadata": input.get("message_metadata"),
}
],
"metadata": input.get("thread_metadata"),
}
run = await self._acreate_thread_and_run(input, thread)
# Starting a new run in an existing thread.
elif "run_id" not in input:
_ = await self.async_client.beta.threads.messages.create(
input["thread_id"],
content=input["content"],
role="user",
attachments=attachments,
metadata=input.get("message_metadata"),
)
run = await self._acreate_run(input)
# Submitting tool outputs to an existing run, outside the AgentExecutor
# framework.
else:
run = await self.async_client.beta.threads.runs.submit_tool_outputs(
**input
)
run = await self._await_for_run(run.id, run.thread_id)
except BaseException as e:
run_manager.on_chain_error(e)
raise e
try:
response = self._get_response(run)
except BaseException as e:
run_manager.on_chain_error(e, metadata=run.dict())
raise e
else:
run_manager.on_chain_end(response)
return response
def _create_run(self, input: dict) -> Any:
"""Create a new run within an existing thread.
Args:
input (dict): The input data for the new run.
Returns:
Any: The created run object.
"""
allowed_assistant_params = (
"instructions",
"model",
"tools",
"tool_resources",
"run_metadata",
"truncation_strategy",
"max_prompt_tokens",
)
params = {k: v for k, v in input.items() if k in allowed_assistant_params}
return self.client.beta.threads.runs.create(
input["thread_id"],
assistant_id=self.assistant_id,
**params,
)
def _create_thread_and_run(self, input: dict, thread: dict) -> Any:
"""Create a new thread and run.
Args:
input (dict): The input data for the run.
thread (dict): The thread data to create.
Returns:
Any: The created thread and run.
"""
params = {
k: v
for k, v in input.items()
if k in ("instructions", "model", "tools", "run_metadata")
}
if tool_resources := input.get("tool_resources"):
thread["tool_resources"] = tool_resources
run = self.client.beta.threads.create_and_run(
assistant_id=self.assistant_id,
thread=thread,
**params,
)
return run
async def _acreate_run(self, input: dict) -> Any:
"""Asynchronously create a new run within an existing thread.
Args:
input (dict): The input data for the new run.
Returns:
Any: The created run object.
"""
params = {
k: v
for k, v in input.items()
if k in ("instructions", "model", "tools", "tool_resources", "run_metadata")
}
return await self.async_client.beta.threads.runs.create(
input["thread_id"],
assistant_id=self.assistant_id,
**params,
)
async def _acreate_thread_and_run(self, input: dict, thread: dict) -> Any:
"""Asynchronously create a new thread and run simultaneously.
Args:
input (dict): The input data for the run.
thread (dict): The thread data to create.
Returns:
Any: The created thread and run.
"""
params = {
k: v
for k, v in input.items()
if k in ("instructions", "model", "tools", "run_metadata")
}
if tool_resources := input.get("tool_resources"):
thread["tool_resources"] = tool_resources
run = await self.async_client.beta.threads.create_and_run(
assistant_id=self.assistant_id,
thread=thread,
**params,
)
return run
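# An async sketch (hedged), mirroring the sync docstring example; assumes
# OPENAI_API_KEY is set in the environment.
import asyncio

async def main() -> None:
    assistant = await OpenAIAssistantV2Runnable.acreate_assistant(
        name="math assistant",
        instructions="You are a personal math tutor.",
        tools=[{"type": "code_interpreter"}],
        model="gpt-4o-mini",
    )
    output = await assistant.ainvoke({"content": "What's 10 - 4 raised to the 2.7?"})
    print(output)

asyncio.run(main())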

File diff suppressed because it is too large.


@@ -1,157 +0,0 @@
"""**Callback handlers** allow listening to events in LangChain.
**Class hierarchy:**
.. code-block::
BaseCallbackHandler --> <name>CallbackHandler # Example: AimCallbackHandler
"""
import importlib
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from langchain_community.callbacks.aim_callback import (
AimCallbackHandler,
)
from langchain_community.callbacks.argilla_callback import (
ArgillaCallbackHandler,
)
from langchain_community.callbacks.arize_callback import (
ArizeCallbackHandler,
)
from langchain_community.callbacks.arthur_callback import (
ArthurCallbackHandler,
)
from langchain_community.callbacks.clearml_callback import (
ClearMLCallbackHandler,
)
from langchain_community.callbacks.comet_ml_callback import (
CometCallbackHandler,
)
from langchain_community.callbacks.context_callback import (
ContextCallbackHandler,
)
from langchain_community.callbacks.fiddler_callback import (
FiddlerCallbackHandler,
)
from langchain_community.callbacks.flyte_callback import (
FlyteCallbackHandler,
)
from langchain_community.callbacks.human import (
HumanApprovalCallbackHandler,
)
from langchain_community.callbacks.infino_callback import (
InfinoCallbackHandler,
)
from langchain_community.callbacks.labelstudio_callback import (
LabelStudioCallbackHandler,
)
from langchain_community.callbacks.llmonitor_callback import (
LLMonitorCallbackHandler,
)
from langchain_community.callbacks.manager import (
get_openai_callback,
wandb_tracing_enabled,
)
from langchain_community.callbacks.mlflow_callback import (
MlflowCallbackHandler,
)
from langchain_community.callbacks.openai_info import (
OpenAICallbackHandler,
)
from langchain_community.callbacks.promptlayer_callback import (
PromptLayerCallbackHandler,
)
from langchain_community.callbacks.sagemaker_callback import (
SageMakerCallbackHandler,
)
from langchain_community.callbacks.streamlit import (
LLMThoughtLabeler,
StreamlitCallbackHandler,
)
from langchain_community.callbacks.trubrics_callback import (
TrubricsCallbackHandler,
)
from langchain_community.callbacks.upstash_ratelimit_callback import (
UpstashRatelimitError,
UpstashRatelimitHandler, # noqa: F401
)
from langchain_community.callbacks.uptrain_callback import (
UpTrainCallbackHandler,
)
from langchain_community.callbacks.wandb_callback import (
WandbCallbackHandler,
)
from langchain_community.callbacks.whylabs_callback import (
WhyLabsCallbackHandler,
)
_module_lookup = {
"AimCallbackHandler": "langchain_community.callbacks.aim_callback",
"ArgillaCallbackHandler": "langchain_community.callbacks.argilla_callback",
"ArizeCallbackHandler": "langchain_community.callbacks.arize_callback",
"ArthurCallbackHandler": "langchain_community.callbacks.arthur_callback",
"ClearMLCallbackHandler": "langchain_community.callbacks.clearml_callback",
"CometCallbackHandler": "langchain_community.callbacks.comet_ml_callback",
"ContextCallbackHandler": "langchain_community.callbacks.context_callback",
"FiddlerCallbackHandler": "langchain_community.callbacks.fiddler_callback",
"FlyteCallbackHandler": "langchain_community.callbacks.flyte_callback",
"HumanApprovalCallbackHandler": "langchain_community.callbacks.human",
"InfinoCallbackHandler": "langchain_community.callbacks.infino_callback",
"LLMThoughtLabeler": "langchain_community.callbacks.streamlit",
"LLMonitorCallbackHandler": "langchain_community.callbacks.llmonitor_callback",
"LabelStudioCallbackHandler": "langchain_community.callbacks.labelstudio_callback",
"MlflowCallbackHandler": "langchain_community.callbacks.mlflow_callback",
"OpenAICallbackHandler": "langchain_community.callbacks.openai_info",
"PromptLayerCallbackHandler": "langchain_community.callbacks.promptlayer_callback",
"SageMakerCallbackHandler": "langchain_community.callbacks.sagemaker_callback",
"StreamlitCallbackHandler": "langchain_community.callbacks.streamlit",
"TrubricsCallbackHandler": "langchain_community.callbacks.trubrics_callback",
"UpstashRatelimitError": "langchain_community.callbacks.upstash_ratelimit_callback",
"UpstashRatelimitHandler": "langchain_community.callbacks.upstash_ratelimit_callback", # noqa
"UpTrainCallbackHandler": "langchain_community.callbacks.uptrain_callback",
"WandbCallbackHandler": "langchain_community.callbacks.wandb_callback",
"WhyLabsCallbackHandler": "langchain_community.callbacks.whylabs_callback",
"get_openai_callback": "langchain_community.callbacks.manager",
"wandb_tracing_enabled": "langchain_community.callbacks.manager",
}
def __getattr__(name: str) -> Any:
if name in _module_lookup:
module = importlib.import_module(_module_lookup[name])
return getattr(module, name)
raise AttributeError(f"module {__name__} has no attribute {name}")
__all__ = [
"AimCallbackHandler",
"ArgillaCallbackHandler",
"ArizeCallbackHandler",
"ArthurCallbackHandler",
"ClearMLCallbackHandler",
"CometCallbackHandler",
"ContextCallbackHandler",
"FiddlerCallbackHandler",
"FlyteCallbackHandler",
"HumanApprovalCallbackHandler",
"InfinoCallbackHandler",
"LLMThoughtLabeler",
"LLMonitorCallbackHandler",
"LabelStudioCallbackHandler",
"MlflowCallbackHandler",
"OpenAICallbackHandler",
"PromptLayerCallbackHandler",
"SageMakerCallbackHandler",
"StreamlitCallbackHandler",
"TrubricsCallbackHandler",
"UpstashRatelimitError",
"UpstashRatelimitHandler",
"UpTrainCallbackHandler",
"WandbCallbackHandler",
"WhyLabsCallbackHandler",
"get_openai_callback",
"wandb_tracing_enabled",
]
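# For example: attribute access triggers the lazy import above, so heavy
# integration dependencies are only imported when actually used.
from langchain_community.callbacks import get_openai_callback

with get_openai_callback() as cb:
    ...  # run chains or LLM calls here
print(cb.total_tokens)  # aggregate token usage recorded inside the block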


@@ -1,434 +0,0 @@
from copy import deepcopy
from typing import Any, Dict, List, Optional
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult
from langchain_core.utils import guard_import
def import_aim() -> Any:
"""Import the aim python package and raise an error if it is not installed."""
return guard_import("aim")
class BaseMetadataCallbackHandler:
"""Callback handler for the metadata and associated function states for callbacks.
Attributes:
step (int): The current step.
starts (int): The number of times the start method has been called.
ends (int): The number of times the end method has been called.
errors (int): The number of times the error method has been called.
text_ctr (int): The number of times the text method has been called.
ignore_llm_ (bool): Whether to ignore llm callbacks.
ignore_chain_ (bool): Whether to ignore chain callbacks.
ignore_agent_ (bool): Whether to ignore agent callbacks.
ignore_retriever_ (bool): Whether to ignore retriever callbacks.
always_verbose_ (bool): Whether to always be verbose.
chain_starts (int): The number of times the chain start method has been called.
chain_ends (int): The number of times the chain end method has been called.
llm_starts (int): The number of times the llm start method has been called.
llm_ends (int): The number of times the llm end method has been called.
llm_streams (int): The number of times the new-token method has been called.
tool_starts (int): The number of times the tool start method has been called.
tool_ends (int): The number of times the tool end method has been called.
agent_ends (int): The number of times the agent end method has been called.
"""
def __init__(self) -> None:
self.step = 0
self.starts = 0
self.ends = 0
self.errors = 0
self.text_ctr = 0
self.ignore_llm_ = False
self.ignore_chain_ = False
self.ignore_agent_ = False
self.ignore_retriever_ = False
self.always_verbose_ = False
self.chain_starts = 0
self.chain_ends = 0
self.llm_starts = 0
self.llm_ends = 0
self.llm_streams = 0
self.tool_starts = 0
self.tool_ends = 0
self.agent_ends = 0
@property
def always_verbose(self) -> bool:
"""Whether to call verbose callbacks even if verbose is False."""
return self.always_verbose_
@property
def ignore_llm(self) -> bool:
"""Whether to ignore LLM callbacks."""
return self.ignore_llm_
@property
def ignore_chain(self) -> bool:
"""Whether to ignore chain callbacks."""
return self.ignore_chain_
@property
def ignore_agent(self) -> bool:
"""Whether to ignore agent callbacks."""
return self.ignore_agent_
@property
def ignore_retriever(self) -> bool:
"""Whether to ignore retriever callbacks."""
return self.ignore_retriever_
def get_custom_callback_meta(self) -> Dict[str, Any]:
return {
"step": self.step,
"starts": self.starts,
"ends": self.ends,
"errors": self.errors,
"text_ctr": self.text_ctr,
"chain_starts": self.chain_starts,
"chain_ends": self.chain_ends,
"llm_starts": self.llm_starts,
"llm_ends": self.llm_ends,
"llm_streams": self.llm_streams,
"tool_starts": self.tool_starts,
"tool_ends": self.tool_ends,
"agent_ends": self.agent_ends,
}
def reset_callback_meta(self) -> None:
"""Reset the callback metadata."""
self.step = 0
self.starts = 0
self.ends = 0
self.errors = 0
self.text_ctr = 0
self.ignore_llm_ = False
self.ignore_chain_ = False
self.ignore_agent_ = False
self.always_verbose_ = False
self.chain_starts = 0
self.chain_ends = 0
self.llm_starts = 0
self.llm_ends = 0
self.llm_streams = 0
self.tool_starts = 0
self.tool_ends = 0
self.agent_ends = 0
return None
class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
"""Callback Handler that logs to Aim.
Parameters:
repo (:obj:`str`, optional): Aim repository path or Repo object to which
Run object is bound. If skipped, default Repo is used.
experiment_name (:obj:`str`, optional): Sets Run's `experiment` property.
'default' if not specified. Can be used later to query runs/sequences.
system_tracking_interval (:obj:`int`, optional): Sets the tracking interval
in seconds for system usage metrics (CPU, Memory, etc.). Set to `None`
to disable system metrics tracking.
log_system_params (:obj:`bool`, optional): Enable/Disable logging of system
params such as installed packages, git info, environment variables, etc.
This handler implements the relevant callback methods, formats each
callback's input together with metadata about the state of the LLM run,
and then logs the result to Aim.
"""
def __init__(
self,
repo: Optional[str] = None,
experiment_name: Optional[str] = None,
system_tracking_interval: Optional[int] = 10,
log_system_params: bool = True,
) -> None:
"""Initialize callback handler."""
super().__init__()
aim = import_aim()
self.repo = repo
self.experiment_name = experiment_name
self.system_tracking_interval = system_tracking_interval
self.log_system_params = log_system_params
self._run = aim.Run(
repo=self.repo,
experiment=self.experiment_name,
system_tracking_interval=self.system_tracking_interval,
log_system_params=self.log_system_params,
)
self._run_hash = self._run.hash
self.action_records: list = []
def setup(self, **kwargs: Any) -> None:
aim = import_aim()
if not self._run:
if self._run_hash:
self._run = aim.Run(
self._run_hash,
repo=self.repo,
system_tracking_interval=self.system_tracking_interval,
)
else:
self._run = aim.Run(
repo=self.repo,
experiment=self.experiment_name,
system_tracking_interval=self.system_tracking_interval,
log_system_params=self.log_system_params,
)
self._run_hash = self._run.hash
if kwargs:
for key, value in kwargs.items():
self._run.set(key, value, strict=False)
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Run when LLM starts."""
aim = import_aim()
self.step += 1
self.llm_starts += 1
self.starts += 1
resp = {"action": "on_llm_start"}
resp.update(self.get_custom_callback_meta())
prompts_res = deepcopy(prompts)
self._run.track(
[aim.Text(prompt) for prompt in prompts_res],
name="on_llm_start",
context=resp,
)
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Run when LLM ends running."""
aim = import_aim()
self.step += 1
self.llm_ends += 1
self.ends += 1
resp = {"action": "on_llm_end"}
resp.update(self.get_custom_callback_meta())
response_res = deepcopy(response)
generated = [
aim.Text(generation.text)
for generations in response_res.generations
for generation in generations
]
self._run.track(
generated,
name="on_llm_end",
context=resp,
)
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Run when LLM generates a new token."""
self.step += 1
self.llm_streams += 1
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when LLM errors."""
self.step += 1
self.errors += 1
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Run when chain starts running."""
aim = import_aim()
self.step += 1
self.chain_starts += 1
self.starts += 1
resp = {"action": "on_chain_start"}
resp.update(self.get_custom_callback_meta())
inputs_res = deepcopy(inputs)
self._run.track(
aim.Text(inputs_res["input"]), name="on_chain_start", context=resp
)
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Run when chain ends running."""
aim = import_aim()
self.step += 1
self.chain_ends += 1
self.ends += 1
resp = {"action": "on_chain_end"}
resp.update(self.get_custom_callback_meta())
outputs_res = deepcopy(outputs)
self._run.track(
aim.Text(outputs_res["output"]), name="on_chain_end", context=resp
)
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when chain errors."""
self.step += 1
self.errors += 1
def on_tool_start(
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> None:
"""Run when tool starts running."""
aim = import_aim()
self.step += 1
self.tool_starts += 1
self.starts += 1
resp = {"action": "on_tool_start"}
resp.update(self.get_custom_callback_meta())
self._run.track(aim.Text(input_str), name="on_tool_start", context=resp)
def on_tool_end(self, output: Any, **kwargs: Any) -> None:
"""Run when tool ends running."""
output = str(output)
aim = import_aim()
self.step += 1
self.tool_ends += 1
self.ends += 1
resp = {"action": "on_tool_end"}
resp.update(self.get_custom_callback_meta())
self._run.track(aim.Text(output), name="on_tool_end", context=resp)
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when tool errors."""
self.step += 1
self.errors += 1
def on_text(self, text: str, **kwargs: Any) -> None:
"""
Run when agent is ending.
"""
self.step += 1
self.text_ctr += 1
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Run when agent ends running."""
aim = import_aim()
self.step += 1
self.agent_ends += 1
self.ends += 1
resp = {"action": "on_agent_finish"}
resp.update(self.get_custom_callback_meta())
finish_res = deepcopy(finish)
text = "OUTPUT:\n{}\n\nLOG:\n{}".format(
finish_res.return_values["output"], finish_res.log
)
self._run.track(aim.Text(text), name="on_agent_finish", context=resp)
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Run on agent action."""
aim = import_aim()
self.step += 1
self.tool_starts += 1
self.starts += 1
resp = {
"action": "on_agent_action",
"tool": action.tool,
}
resp.update(self.get_custom_callback_meta())
action_res = deepcopy(action)
text = "TOOL INPUT:\n{}\n\nLOG:\n{}".format(
action_res.tool_input, action_res.log
)
self._run.track(aim.Text(text), name="on_agent_action", context=resp)
def flush_tracker(
self,
repo: Optional[str] = None,
experiment_name: Optional[str] = None,
system_tracking_interval: Optional[int] = 10,
log_system_params: bool = True,
langchain_asset: Any = None,
reset: bool = True,
finish: bool = False,
) -> None:
"""Flush the tracker and reset the session.
Args:
repo (:obj:`str`, optional): Aim repository path or Repo object to which
Run object is bound. If skipped, default Repo is used.
experiment_name (:obj:`str`, optional): Sets Run's `experiment` property.
'default' if not specified. Can be used later to query runs/sequences.
system_tracking_interval (:obj:`int`, optional): Sets the tracking interval
in seconds for system usage metrics (CPU, Memory, etc.). Set to `None`
to disable system metrics tracking.
log_system_params (:obj:`bool`, optional): Enable/Disable logging of system
params such as installed packages, git info, environment variables, etc.
langchain_asset: The langchain asset to save.
reset: Whether to reset the session.
finish: Whether to finish the run.
Returns:
None
"""
if langchain_asset:
try:
for key, value in langchain_asset.dict().items():
self._run.set(key, value, strict=False)
except Exception:
pass
if finish or reset:
self._run.close()
self.reset_callback_meta()
if reset:
aim = import_aim()
self.repo = repo if repo else self.repo
self.experiment_name = (
experiment_name if experiment_name else self.experiment_name
)
            # NOTE: falsy arguments (None / False) fall back to the previous
            # settings here, so passing system_tracking_interval=None or
            # log_system_params=False via flush_tracker has no effect.
            self.system_tracking_interval = (
                system_tracking_interval
                if system_tracking_interval
                else self.system_tracking_interval
            )
            self.log_system_params = (
                log_system_params if log_system_params else self.log_system_params
            )
self._run = aim.Run(
repo=self.repo,
experiment=self.experiment_name,
system_tracking_interval=self.system_tracking_interval,
log_system_params=self.log_system_params,
)
self._run_hash = self._run.hash
self.action_records = []
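# Illustrative usage sketch (an editorial addition, not part of the original
# module): `llm` is a placeholder for any LangChain LLM object.
def _example_aim_usage(llm: Any) -> None:
    """Attach the Aim handler to an LLM and flush between experiments."""
    aim_callback = AimCallbackHandler(repo=".", experiment_name="scenario 1")
    llm.generate(["Tell me a joke"], callbacks=[aim_callback])
    # Log the asset's parameters, reset the session, and start a fresh Aim run.
    aim_callback.flush_tracker(langchain_asset=llm, reset=True)
    # When finished, close the underlying run instead of starting a new one.
    aim_callback.flush_tracker(reset=False, finish=True)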

View File

@@ -1,349 +0,0 @@
import os
import warnings
from typing import Any, Dict, List, Optional, cast
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult
from packaging.version import parse
class ArgillaCallbackHandler(BaseCallbackHandler):
"""Callback Handler that logs into Argilla.
Args:
dataset_name: name of the `FeedbackDataset` in Argilla. Note that it must
exist in advance. If you need help on how to create a `FeedbackDataset` in
Argilla, please visit
https://docs.argilla.io/en/latest/tutorials_and_integrations/integrations/use_argilla_callback_in_langchain.html.
        workspace_name: name of the workspace in Argilla where the specified
            `FeedbackDataset` lives. Defaults to `None`, which means that the
            default workspace will be used.
        api_url: URL of the Argilla Server that we want to use, and where the
            `FeedbackDataset` lives. Defaults to `None`, which means that either
            the `ARGILLA_API_URL` environment variable or the default will be used.
        api_key: API Key to connect to the Argilla Server. Defaults to `None`, which
            means that either the `ARGILLA_API_KEY` environment variable or the
            default will be used.
Raises:
ImportError: if the `argilla` package is not installed.
ConnectionError: if the connection to Argilla fails.
FileNotFoundError: if the `FeedbackDataset` retrieval from Argilla fails.
Examples:
>>> from langchain_community.llms import OpenAI
>>> from langchain_community.callbacks import ArgillaCallbackHandler
>>> argilla_callback = ArgillaCallbackHandler(
... dataset_name="my-dataset",
... workspace_name="my-workspace",
... api_url="http://localhost:6900",
... api_key="argilla.apikey",
... )
>>> llm = OpenAI(
... temperature=0,
... callbacks=[argilla_callback],
... verbose=True,
... openai_api_key="API_KEY_HERE",
... )
>>> llm.generate([
... "What is the best NLP-annotation tool out there? (no bias at all)",
... ])
"Argilla, no doubt about it."
"""
REPO_URL: str = "https://github.com/argilla-io/argilla"
ISSUES_URL: str = f"{REPO_URL}/issues"
BLOG_URL: str = "https://docs.argilla.io/en/latest/tutorials_and_integrations/integrations/use_argilla_callback_in_langchain.html"
DEFAULT_API_URL: str = "http://localhost:6900"
def __init__(
self,
dataset_name: str,
workspace_name: Optional[str] = None,
api_url: Optional[str] = None,
api_key: Optional[str] = None,
) -> None:
"""Initializes the `ArgillaCallbackHandler`.
Args:
dataset_name: name of the `FeedbackDataset` in Argilla. Note that it must
exist in advance. If you need help on how to create a `FeedbackDataset`
in Argilla, please visit
https://docs.argilla.io/en/latest/tutorials_and_integrations/integrations/use_argilla_callback_in_langchain.html.
            workspace_name: name of the workspace in Argilla where the specified
                `FeedbackDataset` lives. Defaults to `None`, which means that the
                default workspace will be used.
            api_url: URL of the Argilla Server that we want to use, and where the
                `FeedbackDataset` lives. Defaults to `None`, which means that either
                the `ARGILLA_API_URL` environment variable or the default will be
                used.
            api_key: API Key to connect to the Argilla Server. Defaults to `None`,
                which means that either the `ARGILLA_API_KEY` environment variable
                or the default will be used.
Raises:
ImportError: if the `argilla` package is not installed.
ConnectionError: if the connection to Argilla fails.
FileNotFoundError: if the `FeedbackDataset` retrieval from Argilla fails.
"""
super().__init__()
# Import Argilla (not via `import_argilla` to keep hints in IDEs)
try:
import argilla as rg
self.ARGILLA_VERSION = rg.__version__
except ImportError:
raise ImportError(
"To use the Argilla callback manager you need to have the `argilla` "
"Python package installed. Please install it with `pip install argilla`"
)
# Check whether the Argilla version is compatible
if parse(self.ARGILLA_VERSION) < parse("1.8.0"):
raise ImportError(
f"The installed `argilla` version is {self.ARGILLA_VERSION} but "
"`ArgillaCallbackHandler` requires at least version 1.8.0. Please "
"upgrade `argilla` with `pip install --upgrade argilla`."
)
# Show a warning message if Argilla will assume the default values will be used
if api_url is None and os.getenv("ARGILLA_API_URL") is None:
warnings.warn(
(
"Since `api_url` is None, and the env var `ARGILLA_API_URL` is not"
f" set, it will default to `{self.DEFAULT_API_URL}`, which is the"
" default API URL in Argilla Quickstart."
),
)
api_url = self.DEFAULT_API_URL
if api_key is None and os.getenv("ARGILLA_API_KEY") is None:
self.DEFAULT_API_KEY = (
"admin.apikey"
if parse(self.ARGILLA_VERSION) < parse("1.11.0")
else "owner.apikey"
)
warnings.warn(
(
"Since `api_key` is None, and the env var `ARGILLA_API_KEY` is not"
f" set, it will default to `{self.DEFAULT_API_KEY}`, which is the"
" default API key in Argilla Quickstart."
),
)
api_key = self.DEFAULT_API_KEY
# Connect to Argilla with the provided credentials, if applicable
try:
rg.init(api_key=api_key, api_url=api_url)
except Exception as e:
raise ConnectionError(
f"Could not connect to Argilla with exception: '{e}'.\n"
"Please check your `api_key` and `api_url`, and make sure that "
"the Argilla server is up and running. If the problem persists "
f"please report it to {self.ISSUES_URL} as an `integration` issue."
) from e
# Set the Argilla variables
self.dataset_name = dataset_name
self.workspace_name = workspace_name or rg.get_workspace()
# Retrieve the `FeedbackDataset` from Argilla (without existing records)
try:
extra_args = {}
if parse(self.ARGILLA_VERSION) < parse("1.14.0"):
warnings.warn(
f"You have Argilla {self.ARGILLA_VERSION}, but Argilla 1.14.0 or"
" higher is recommended.",
UserWarning,
)
extra_args = {"with_records": False}
self.dataset = rg.FeedbackDataset.from_argilla(
name=self.dataset_name,
workspace=self.workspace_name,
**extra_args,
)
except Exception as e:
raise FileNotFoundError(
f"`FeedbackDataset` retrieval from Argilla failed with exception `{e}`."
f"\nPlease check that the dataset with name={self.dataset_name} in the"
f" workspace={self.workspace_name} exists in advance. If you need help"
" on how to create a `langchain`-compatible `FeedbackDataset` in"
f" Argilla, please visit {self.BLOG_URL}. If the problem persists"
f" please report it to {self.ISSUES_URL} as an `integration` issue."
) from e
supported_fields = ["prompt", "response"]
if supported_fields != [field.name for field in self.dataset.fields]:
raise ValueError(
f"`FeedbackDataset` with name={self.dataset_name} in the workspace="
f"{self.workspace_name} had fields that are not supported yet for the"
f"`langchain` integration. Supported fields are: {supported_fields},"
f" and the current `FeedbackDataset` fields are {[field.name for field in self.dataset.fields]}." # noqa: E501
" For more information on how to create a `langchain`-compatible"
f" `FeedbackDataset` in Argilla, please visit {self.BLOG_URL}."
)
self.prompts: Dict[str, List[str]] = {}
warnings.warn(
(
"The `ArgillaCallbackHandler` is currently in beta and is subject to"
" change based on updates to `langchain`. Please report any issues to"
f" {self.ISSUES_URL} as an `integration` issue."
),
)
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Save the prompts in memory when an LLM starts."""
self.prompts.update({str(kwargs["parent_run_id"] or kwargs["run_id"]): prompts})
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Do nothing when a new token is generated."""
pass
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Log records to Argilla when an LLM ends."""
# Do nothing if there's a parent_run_id, since we will log the records when
# the chain ends
if kwargs["parent_run_id"]:
return
# Creates the records and adds them to the `FeedbackDataset`
prompts = self.prompts[str(kwargs["run_id"])]
for prompt, generations in zip(prompts, response.generations):
self.dataset.add_records(
records=[
{
"fields": {
"prompt": prompt,
"response": generation.text.strip(),
},
}
for generation in generations
]
)
        # Pop current run from `self.prompts`
self.prompts.pop(str(kwargs["run_id"]))
if parse(self.ARGILLA_VERSION) < parse("1.14.0"):
# Push the records to Argilla
self.dataset.push_to_argilla()
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
"""Do nothing when LLM outputs an error."""
pass
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""If the key `input` is in `inputs`, then save it in `self.prompts` using
either the `parent_run_id` or the `run_id` as the key. This is done so that
we don't log the same input prompt twice, once when the LLM starts and once
when the chain starts.
"""
if "input" in inputs:
self.prompts.update(
{
str(kwargs["parent_run_id"] or kwargs["run_id"]): (
inputs["input"]
if isinstance(inputs["input"], list)
else [inputs["input"]]
)
}
)
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""If either the `parent_run_id` or the `run_id` is in `self.prompts`, then
log the outputs to Argilla, and pop the run from `self.prompts`. The behavior
differs if the output is a list or not.
"""
if not any(
key in self.prompts
for key in [str(kwargs["parent_run_id"]), str(kwargs["run_id"])]
):
return
prompts: List = self.prompts.get(str(kwargs["parent_run_id"])) or cast(
List, self.prompts.get(str(kwargs["run_id"]), [])
)
for chain_output_key, chain_output_val in outputs.items():
if isinstance(chain_output_val, list):
# Creates the records and adds them to the `FeedbackDataset`
self.dataset.add_records(
records=[
{
"fields": {
"prompt": prompt,
"response": output["text"].strip(),
},
}
for prompt, output in zip(prompts, chain_output_val)
]
)
else:
# Creates the records and adds them to the `FeedbackDataset`
self.dataset.add_records(
records=[
{
"fields": {
"prompt": " ".join(prompts),
"response": chain_output_val.strip(),
},
}
]
)
        # Pop current run from `self.prompts`
if str(kwargs["parent_run_id"]) in self.prompts:
self.prompts.pop(str(kwargs["parent_run_id"]))
if str(kwargs["run_id"]) in self.prompts:
self.prompts.pop(str(kwargs["run_id"]))
if parse(self.ARGILLA_VERSION) < parse("1.14.0"):
# Push the records to Argilla
self.dataset.push_to_argilla()
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
"""Do nothing when LLM chain outputs an error."""
pass
def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
**kwargs: Any,
) -> None:
"""Do nothing when tool starts."""
pass
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Do nothing when agent takes a specific action."""
pass
def on_tool_end(
self,
output: Any,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Do nothing when tool ends."""
pass
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
"""Do nothing when tool outputs an error."""
pass
def on_text(self, text: str, **kwargs: Any) -> None:
"""Do nothing"""
pass
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Do nothing"""
pass

View File

@@ -1,213 +0,0 @@
from datetime import datetime
from typing import Any, Dict, List, Optional
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult
from langchain_community.callbacks.utils import import_pandas
class ArizeCallbackHandler(BaseCallbackHandler):
"""Callback Handler that logs to Arize."""
def __init__(
self,
model_id: Optional[str] = None,
model_version: Optional[str] = None,
SPACE_KEY: Optional[str] = None,
API_KEY: Optional[str] = None,
) -> None:
"""Initialize callback handler."""
super().__init__()
self.model_id = model_id
self.model_version = model_version
self.space_key = SPACE_KEY
self.api_key = API_KEY
self.prompt_records: List[str] = []
self.response_records: List[str] = []
self.prediction_ids: List[str] = []
self.pred_timestamps: List[int] = []
self.response_embeddings: List[float] = []
self.prompt_embeddings: List[float] = []
self.prompt_tokens = 0
self.completion_tokens = 0
self.total_tokens = 0
self.step = 0
from arize.pandas.embeddings import EmbeddingGenerator, UseCases
from arize.pandas.logger import Client
self.generator = EmbeddingGenerator.from_use_case(
use_case=UseCases.NLP.SEQUENCE_CLASSIFICATION,
model_name="distilbert-base-uncased",
tokenizer_max_length=512,
batch_size=256,
)
        if SPACE_KEY == "SPACE_KEY" or API_KEY == "API_KEY":
            raise ValueError("❌ CHANGE SPACE AND API KEYS")
        # Only construct the client once real keys have been provided.
        self.arize_client = Client(space_key=SPACE_KEY, api_key=API_KEY)
        print("✅ Arize client setup done! Now you can start using Arize!")  # noqa: T201
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
for prompt in prompts:
self.prompt_records.append(prompt.replace("\n", ""))
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Do nothing."""
pass
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
pd = import_pandas()
from arize.utils.types import (
EmbeddingColumnNames,
Environments,
ModelTypes,
Schema,
)
# Safe check if 'llm_output' and 'token_usage' exist
if response.llm_output and "token_usage" in response.llm_output:
self.prompt_tokens = response.llm_output["token_usage"].get(
"prompt_tokens", 0
)
self.total_tokens = response.llm_output["token_usage"].get(
"total_tokens", 0
)
self.completion_tokens = response.llm_output["token_usage"].get(
"completion_tokens", 0
)
else:
self.prompt_tokens = self.total_tokens = self.completion_tokens = (
0 # assign default value
)
for generations in response.generations:
for generation in generations:
prompt = self.prompt_records[self.step]
self.step = self.step + 1
prompt_embedding = pd.Series(
self.generator.generate_embeddings(
text_col=pd.Series(prompt.replace("\n", " "))
).reset_index(drop=True)
)
# Assigning text to response_text instead of response
response_text = generation.text.replace("\n", " ")
response_embedding = pd.Series(
self.generator.generate_embeddings(
text_col=pd.Series(generation.text.replace("\n", " "))
).reset_index(drop=True)
)
pred_timestamp = datetime.now().timestamp()
# Define the columns and data
columns = [
"prediction_ts",
"response",
"prompt",
"response_vector",
"prompt_vector",
"prompt_token",
"completion_token",
"total_token",
]
                data = [
                    [
                        pred_timestamp,
                        response_text,
                        prompt,
                        response_embedding[0],
                        prompt_embedding[0],
                        self.prompt_tokens,
                        # Order must match `columns` above:
                        # "completion_token" before "total_token".
                        self.completion_tokens,
                        self.total_tokens,
                    ]
                ]
# Create the DataFrame
df = pd.DataFrame(data, columns=columns)
# Declare prompt and response columns
prompt_columns = EmbeddingColumnNames(
vector_column_name="prompt_vector", data_column_name="prompt"
)
response_columns = EmbeddingColumnNames(
vector_column_name="response_vector", data_column_name="response"
)
schema = Schema(
timestamp_column_name="prediction_ts",
tag_column_names=[
"prompt_token",
"completion_token",
"total_token",
],
prompt_column_names=prompt_columns,
response_column_names=response_columns,
)
response_from_arize = self.arize_client.log(
dataframe=df,
schema=schema,
model_id=self.model_id,
model_version=self.model_version,
model_type=ModelTypes.GENERATIVE_LLM,
environment=Environments.PRODUCTION,
)
if response_from_arize.status_code == 200:
print("✅ Successfully logged data to Arize!") # noqa: T201
else:
print(f'❌ Logging failed "{response_from_arize.text}"') # noqa: T201
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
"""Do nothing."""
pass
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
pass
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Do nothing."""
pass
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
"""Do nothing."""
pass
def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
**kwargs: Any,
) -> None:
pass
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Do nothing."""
pass
def on_tool_end(
self,
output: Any,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
pass
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
pass
def on_text(self, text: str, **kwargs: Any) -> None:
pass
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
pass
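# Illustrative usage sketch (an editorial addition, not part of the original
# module): all keys are placeholders, and `llm` stands in for any LangChain LLM.
def _example_arize_usage(llm: Any) -> None:
    """Attach the Arize handler so each generation is logged in on_llm_end."""
    arize_callback = ArizeCallbackHandler(
        model_id="llm-demo",
        model_version="1.0",
        SPACE_KEY="<your-space-key>",
        API_KEY="<your-api-key>",
    )
    llm.generate(["What is LangChain?"], callbacks=[arize_callback])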

View File

@@ -1,297 +0,0 @@
"""ArthurAI's Callback Handler."""
from __future__ import annotations
import os
import uuid
from collections import defaultdict
from datetime import datetime
from time import time
from typing import TYPE_CHECKING, Any, DefaultDict, Dict, List, Optional
import numpy as np
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult
if TYPE_CHECKING:
import arthurai
from arthurai.core.models import ArthurModel
PROMPT_TOKENS = "prompt_tokens"
COMPLETION_TOKENS = "completion_tokens"
TOKEN_USAGE = "token_usage"
FINISH_REASON = "finish_reason"
DURATION = "duration"
def _lazy_load_arthur() -> arthurai:
"""Lazy load Arthur."""
try:
import arthurai
except ImportError as e:
raise ImportError(
"To use the ArthurCallbackHandler you need the"
" `arthurai` package. Please install it with"
" `pip install arthurai`.",
e,
)
return arthurai
class ArthurCallbackHandler(BaseCallbackHandler):
"""Callback Handler that logs to Arthur platform.
Arthur helps enterprise teams optimize model operations
and performance at scale. The Arthur API tracks model
performance, explainability, and fairness across tabular,
NLP, and CV models. Our API is model- and platform-agnostic,
and continuously scales with complex and dynamic enterprise needs.
To learn more about Arthur, visit our website at
https://www.arthur.ai/ or read the Arthur docs at
https://docs.arthur.ai/
"""
def __init__(
self,
arthur_model: ArthurModel,
) -> None:
"""Initialize callback handler."""
super().__init__()
arthurai = _lazy_load_arthur()
Stage = arthurai.common.constants.Stage
ValueType = arthurai.common.constants.ValueType
self.arthur_model = arthur_model
# save the attributes of this model to be used when preparing
# inferences to log to Arthur in on_llm_end()
self.attr_names = set([a.name for a in self.arthur_model.get_attributes()])
self.input_attr = [
x
for x in self.arthur_model.get_attributes()
if x.stage == Stage.ModelPipelineInput
and x.value_type == ValueType.Unstructured_Text
][0].name
self.output_attr = [
x
for x in self.arthur_model.get_attributes()
if x.stage == Stage.PredictedValue
and x.value_type == ValueType.Unstructured_Text
][0].name
self.token_likelihood_attr = None
if (
len(
[
x
for x in self.arthur_model.get_attributes()
if x.value_type == ValueType.TokenLikelihoods
]
)
> 0
):
self.token_likelihood_attr = [
x
for x in self.arthur_model.get_attributes()
if x.value_type == ValueType.TokenLikelihoods
][0].name
self.run_map: DefaultDict[str, Any] = defaultdict(dict)
@classmethod
def from_credentials(
cls,
model_id: str,
arthur_url: Optional[str] = "https://app.arthur.ai",
arthur_login: Optional[str] = None,
arthur_password: Optional[str] = None,
) -> ArthurCallbackHandler:
"""Initialize callback handler from Arthur credentials.
Args:
model_id (str): The ID of the arthur model to log to.
arthur_url (str, optional): The URL of the Arthur instance to log to.
Defaults to "https://app.arthur.ai".
arthur_login (str, optional): The login to use to connect to Arthur.
Defaults to None.
arthur_password (str, optional): The password to use to connect to
Arthur. Defaults to None.
Returns:
ArthurCallbackHandler: The initialized callback handler.
"""
arthurai = _lazy_load_arthur()
ArthurAI = arthurai.ArthurAI
ResponseClientError = arthurai.common.exceptions.ResponseClientError
# connect to Arthur
if arthur_login is None:
try:
arthur_api_key = os.environ["ARTHUR_API_KEY"]
except KeyError:
raise ValueError(
"No Arthur authentication provided. Either give"
" a login to the ArthurCallbackHandler"
" or set an ARTHUR_API_KEY as an environment variable."
)
arthur = ArthurAI(url=arthur_url, access_key=arthur_api_key)
else:
if arthur_password is None:
arthur = ArthurAI(url=arthur_url, login=arthur_login)
else:
arthur = ArthurAI(
url=arthur_url, login=arthur_login, password=arthur_password
)
# get model from Arthur by the provided model ID
try:
arthur_model = arthur.get_model(model_id)
except ResponseClientError:
raise ValueError(
f"Was unable to retrieve model with id {model_id} from Arthur."
" Make sure the ID corresponds to a model that is currently"
" registered with your Arthur account."
)
return cls(arthur_model)
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""On LLM start, save the input prompts"""
run_id = kwargs["run_id"]
self.run_map[run_id]["input_texts"] = prompts
self.run_map[run_id]["start_time"] = time()
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""On LLM end, send data to Arthur."""
try:
import pytz
except ImportError as e:
raise ImportError(
"Could not import pytz. Please install it with 'pip install pytz'."
) from e
run_id = kwargs["run_id"]
# get the run params from this run ID,
# or raise an error if this run ID has no corresponding metadata in self.run_map
try:
run_map_data = self.run_map[run_id]
except KeyError as e:
raise KeyError(
"This function has been called with a run_id"
" that was never registered in on_llm_start()."
" Restart and try running the LLM again"
) from e
# mark the duration time between on_llm_start() and on_llm_end()
time_from_start_to_end = time() - run_map_data["start_time"]
# create inferences to log to Arthur
inferences = []
for i, generations in enumerate(response.generations):
for generation in generations:
inference = {
"partner_inference_id": str(uuid.uuid4()),
"inference_timestamp": datetime.now(tz=pytz.UTC),
self.input_attr: run_map_data["input_texts"][i],
self.output_attr: generation.text,
}
if generation.generation_info is not None:
# add finish reason to the inference
# if generation info contains a finish reason and
# if the ArthurModel was registered to monitor finish_reason
if (
FINISH_REASON in generation.generation_info
and FINISH_REASON in self.attr_names
):
inference[FINISH_REASON] = generation.generation_info[
FINISH_REASON
]
# add token likelihoods data to the inference if the ArthurModel
# was registered to monitor token likelihoods
logprobs_data = generation.generation_info["logprobs"]
if (
logprobs_data is not None
and self.token_likelihood_attr is not None
):
logprobs = logprobs_data["top_logprobs"]
likelihoods = [
{k: np.exp(v) for k, v in logprobs[i].items()}
for i in range(len(logprobs))
]
inference[self.token_likelihood_attr] = likelihoods
# add token usage counts to the inference if the
# ArthurModel was registered to monitor token usage
if (
isinstance(response.llm_output, dict)
and TOKEN_USAGE in response.llm_output
):
token_usage = response.llm_output[TOKEN_USAGE]
if (
PROMPT_TOKENS in token_usage
and PROMPT_TOKENS in self.attr_names
):
inference[PROMPT_TOKENS] = token_usage[PROMPT_TOKENS]
if (
COMPLETION_TOKENS in token_usage
and COMPLETION_TOKENS in self.attr_names
):
inference[COMPLETION_TOKENS] = token_usage[COMPLETION_TOKENS]
# add inference duration to the inference if the ArthurModel
# was registered to monitor inference duration
if DURATION in self.attr_names:
inference[DURATION] = time_from_start_to_end
inferences.append(inference)
# send inferences to arthur
self.arthur_model.send_inferences(inferences)
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""On chain start, do nothing."""
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""On chain end, do nothing."""
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
"""Do nothing when LLM outputs an error."""
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""On new token, pass."""
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
"""Do nothing when LLM chain outputs an error."""
def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
**kwargs: Any,
) -> None:
"""Do nothing when tool starts."""
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Do nothing when agent takes a specific action."""
def on_tool_end(
self,
output: Any,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Do nothing when tool ends."""
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
"""Do nothing when tool outputs an error."""
def on_text(self, text: str, **kwargs: Any) -> None:
"""Do nothing"""
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Do nothing"""

View File

@@ -1,131 +0,0 @@
import threading
from typing import Any, Dict, List, Union
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult
MODEL_COST_PER_1K_INPUT_TOKENS = {
"anthropic.claude-instant-v1": 0.0008,
"anthropic.claude-v2": 0.008,
"anthropic.claude-v2:1": 0.008,
"anthropic.claude-3-sonnet-20240229-v1:0": 0.003,
"anthropic.claude-3-5-sonnet-20240620-v1:0": 0.003,
"anthropic.claude-3-5-sonnet-20241022-v2:0": 0.003,
"anthropic.claude-3-7-sonnet-20250219-v1:0": 0.003,
"anthropic.claude-3-haiku-20240307-v1:0": 0.00025,
"anthropic.claude-3-opus-20240229-v1:0": 0.015,
"anthropic.claude-3-5-haiku-20241022-v1:0": 0.0008,
}
MODEL_COST_PER_1K_OUTPUT_TOKENS = {
"anthropic.claude-instant-v1": 0.0024,
"anthropic.claude-v2": 0.024,
"anthropic.claude-v2:1": 0.024,
"anthropic.claude-3-sonnet-20240229-v1:0": 0.015,
"anthropic.claude-3-5-sonnet-20240620-v1:0": 0.015,
"anthropic.claude-3-5-sonnet-20241022-v2:0": 0.015,
"anthropic.claude-3-7-sonnet-20250219-v1:0": 0.015,
"anthropic.claude-3-haiku-20240307-v1:0": 0.00125,
"anthropic.claude-3-opus-20240229-v1:0": 0.075,
"anthropic.claude-3-5-haiku-20241022-v1:0": 0.004,
}
def _get_anthropic_claude_token_cost(
    prompt_tokens: int, completion_tokens: int, model_id: Union[str, None]
) -> float:
    """Get the cost of tokens for the Claude model."""
    if model_id:
        # The model ID can be a cross-region (system-defined) inference profile ID,
        # which has a prefix indicating the region (e.g., 'us', 'eu') but
        # shares the same token costs as the "base model".
        # By extracting the "base model ID" (the last two dot-separated segments
        # of the model ID), we can map cross-region inference profile IDs to
        # their corresponding cost entries.
        base_model_id = model_id.split(".")[-2] + "." + model_id.split(".")[-1]
    else:
        base_model_id = None
    if base_model_id not in MODEL_COST_PER_1K_INPUT_TOKENS:
        raise ValueError(
            f"Unknown model: {model_id}. Please provide a valid Anthropic model name."
            " Known models are: " + ", ".join(MODEL_COST_PER_1K_INPUT_TOKENS.keys())
        )
    return (prompt_tokens / 1000) * MODEL_COST_PER_1K_INPUT_TOKENS[base_model_id] + (
        completion_tokens / 1000
    ) * MODEL_COST_PER_1K_OUTPUT_TOKENS[base_model_id]
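# Worked example of the cost arithmetic above (an editorial addition): the
# cross-region profile ID below resolves to the base model
# "anthropic.claude-3-haiku-20240307-v1:0", so 2000 prompt tokens and 1000
# completion tokens cost (2000 / 1000) * 0.00025 + (1000 / 1000) * 0.00125
# = 0.0005 + 0.00125 = 0.00175 USD.
def _example_claude_cost() -> float:
    """Return the cost for a sample cross-region inference profile ID."""
    return _get_anthropic_claude_token_cost(
        prompt_tokens=2000,
        completion_tokens=1000,
        model_id="us.anthropic.claude-3-haiku-20240307-v1:0",
    )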
class BedrockAnthropicTokenUsageCallbackHandler(BaseCallbackHandler):
"""Callback Handler that tracks bedrock anthropic info."""
total_tokens: int = 0
prompt_tokens: int = 0
completion_tokens: int = 0
successful_requests: int = 0
total_cost: float = 0.0
def __init__(self) -> None:
super().__init__()
self._lock = threading.Lock()
def __repr__(self) -> str:
return (
f"Tokens Used: {self.total_tokens}\n"
f"\tPrompt Tokens: {self.prompt_tokens}\n"
f"\tCompletion Tokens: {self.completion_tokens}\n"
f"Successful Requests: {self.successful_requests}\n"
f"Total Cost (USD): ${self.total_cost}"
)
@property
def always_verbose(self) -> bool:
"""Whether to call verbose callbacks even if verbose is False."""
return True
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Print out the prompts."""
pass
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Print out the token."""
pass
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Collect token usage."""
if response.llm_output is None:
return None
if "usage" not in response.llm_output:
with self._lock:
self.successful_requests += 1
return None
# compute tokens and cost for this request
token_usage = response.llm_output["usage"]
completion_tokens = token_usage.get("completion_tokens", 0)
prompt_tokens = token_usage.get("prompt_tokens", 0)
total_tokens = token_usage.get("total_tokens", 0)
model_id = response.llm_output.get("model_id", None)
total_cost = _get_anthropic_claude_token_cost(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
model_id=model_id,
)
# update shared state behind lock
with self._lock:
self.total_cost += total_cost
self.total_tokens += total_tokens
self.prompt_tokens += prompt_tokens
self.completion_tokens += completion_tokens
self.successful_requests += 1
def __copy__(self) -> "BedrockAnthropicTokenUsageCallbackHandler":
"""Return a copy of the callback handler."""
return self
def __deepcopy__(self, memo: Any) -> "BedrockAnthropicTokenUsageCallbackHandler":
"""Return a deep copy of the callback handler."""
return self

View File

@@ -1,518 +0,0 @@
from __future__ import annotations
import tempfile
from copy import deepcopy
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Sequence
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult
from langchain_core.utils import guard_import
from langchain_community.callbacks.utils import (
BaseMetadataCallbackHandler,
flatten_dict,
hash_string,
import_pandas,
import_spacy,
import_textstat,
load_json,
)
if TYPE_CHECKING:
import pandas as pd
def import_clearml() -> Any:
"""Import the clearml python package and raise an error if it is not installed."""
return guard_import("clearml")
class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
"""Callback Handler that logs to ClearML.
Parameters:
        task_type (str): The type of ClearML task, such as "inference", "testing" or "qc"
project_name (str): The clearml project name
tags (list): Tags to add to the task
task_name (str): Name of the clearml task
visualize (bool): Whether to visualize the run.
complexity_metrics (bool): Whether to log complexity metrics
stream_logs (bool): Whether to stream callback actions to ClearML
    This handler formats the input of each callback method with metadata about
    the state of the LLM run, appends the response to the list of records for
    both the {method}_records and action_records attributes, and then logs the
    response to the ClearML console.
"""
def __init__(
self,
task_type: Optional[str] = "inference",
project_name: Optional[str] = "langchain_callback_demo",
tags: Optional[Sequence] = None,
task_name: Optional[str] = None,
visualize: bool = False,
complexity_metrics: bool = False,
stream_logs: bool = False,
) -> None:
"""Initialize callback handler."""
clearml = import_clearml()
spacy = import_spacy()
super().__init__()
self.task_type = task_type
self.project_name = project_name
self.tags = tags
self.task_name = task_name
self.visualize = visualize
self.complexity_metrics = complexity_metrics
self.stream_logs = stream_logs
self.temp_dir = tempfile.TemporaryDirectory()
# Check if ClearML task already exists (e.g. in pipeline)
if clearml.Task.current_task():
self.task = clearml.Task.current_task()
else:
self.task = clearml.Task.init(
task_type=self.task_type,
project_name=self.project_name,
tags=self.tags,
task_name=self.task_name,
output_uri=True,
)
self.logger = self.task.get_logger()
warning = (
"The clearml callback is currently in beta and is subject to change "
"based on updates to `langchain`. Please report any issues to "
"https://github.com/allegroai/clearml/issues with the tag `langchain`."
)
self.logger.report_text(warning, level=30, print_console=True)
        self.callback_columns: list = []
        self.action_records: list = []
        self.nlp = spacy.load("en_core_web_sm")
def _init_resp(self) -> Dict:
return {k: None for k in self.callback_columns}
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Run when LLM starts."""
self.step += 1
self.llm_starts += 1
self.starts += 1
resp = self._init_resp()
resp.update({"action": "on_llm_start"})
resp.update(flatten_dict(serialized))
resp.update(self.get_custom_callback_meta())
for prompt in prompts:
prompt_resp = deepcopy(resp)
prompt_resp["prompts"] = prompt
self.on_llm_start_records.append(prompt_resp)
self.action_records.append(prompt_resp)
if self.stream_logs:
self.logger.report_text(prompt_resp)
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Run when LLM generates a new token."""
self.step += 1
self.llm_streams += 1
resp = self._init_resp()
resp.update({"action": "on_llm_new_token", "token": token})
resp.update(self.get_custom_callback_meta())
self.on_llm_token_records.append(resp)
self.action_records.append(resp)
if self.stream_logs:
self.logger.report_text(resp)
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Run when LLM ends running."""
self.step += 1
self.llm_ends += 1
self.ends += 1
resp = self._init_resp()
resp.update({"action": "on_llm_end"})
resp.update(flatten_dict(response.llm_output or {}))
resp.update(self.get_custom_callback_meta())
for generations in response.generations:
for generation in generations:
generation_resp = deepcopy(resp)
generation_resp.update(flatten_dict(generation.dict()))
generation_resp.update(self.analyze_text(generation.text))
self.on_llm_end_records.append(generation_resp)
self.action_records.append(generation_resp)
if self.stream_logs:
self.logger.report_text(generation_resp)
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when LLM errors."""
self.step += 1
self.errors += 1
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Run when chain starts running."""
self.step += 1
self.chain_starts += 1
self.starts += 1
resp = self._init_resp()
resp.update({"action": "on_chain_start"})
resp.update(flatten_dict(serialized))
resp.update(self.get_custom_callback_meta())
chain_input = inputs.get("input", inputs.get("human_input"))
if isinstance(chain_input, str):
input_resp = deepcopy(resp)
input_resp["input"] = chain_input
self.on_chain_start_records.append(input_resp)
self.action_records.append(input_resp)
if self.stream_logs:
self.logger.report_text(input_resp)
elif isinstance(chain_input, list):
for inp in chain_input:
input_resp = deepcopy(resp)
input_resp.update(inp)
self.on_chain_start_records.append(input_resp)
self.action_records.append(input_resp)
if self.stream_logs:
self.logger.report_text(input_resp)
else:
raise ValueError("Unexpected data format provided!")
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Run when chain ends running."""
self.step += 1
self.chain_ends += 1
self.ends += 1
resp = self._init_resp()
resp.update(
{
"action": "on_chain_end",
"outputs": outputs.get("output", outputs.get("text")),
}
)
resp.update(self.get_custom_callback_meta())
self.on_chain_end_records.append(resp)
self.action_records.append(resp)
if self.stream_logs:
self.logger.report_text(resp)
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when chain errors."""
self.step += 1
self.errors += 1
def on_tool_start(
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> None:
"""Run when tool starts running."""
self.step += 1
self.tool_starts += 1
self.starts += 1
resp = self._init_resp()
resp.update({"action": "on_tool_start", "input_str": input_str})
resp.update(flatten_dict(serialized))
resp.update(self.get_custom_callback_meta())
self.on_tool_start_records.append(resp)
self.action_records.append(resp)
if self.stream_logs:
self.logger.report_text(resp)
def on_tool_end(self, output: Any, **kwargs: Any) -> None:
"""Run when tool ends running."""
output = str(output)
self.step += 1
self.tool_ends += 1
self.ends += 1
resp = self._init_resp()
resp.update({"action": "on_tool_end", "output": output})
resp.update(self.get_custom_callback_meta())
self.on_tool_end_records.append(resp)
self.action_records.append(resp)
if self.stream_logs:
self.logger.report_text(resp)
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when tool errors."""
self.step += 1
self.errors += 1
def on_text(self, text: str, **kwargs: Any) -> None:
"""
Run when agent is ending.
"""
self.step += 1
self.text_ctr += 1
resp = self._init_resp()
resp.update({"action": "on_text", "text": text})
resp.update(self.get_custom_callback_meta())
self.on_text_records.append(resp)
self.action_records.append(resp)
if self.stream_logs:
self.logger.report_text(resp)
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Run when agent ends running."""
self.step += 1
self.agent_ends += 1
self.ends += 1
resp = self._init_resp()
resp.update(
{
"action": "on_agent_finish",
"output": finish.return_values["output"],
"log": finish.log,
}
)
resp.update(self.get_custom_callback_meta())
self.on_agent_finish_records.append(resp)
self.action_records.append(resp)
if self.stream_logs:
self.logger.report_text(resp)
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Run on agent action."""
self.step += 1
self.tool_starts += 1
self.starts += 1
resp = self._init_resp()
resp.update(
{
"action": "on_agent_action",
"tool": action.tool,
"tool_input": action.tool_input,
"log": action.log,
}
)
resp.update(self.get_custom_callback_meta())
self.on_agent_action_records.append(resp)
self.action_records.append(resp)
if self.stream_logs:
self.logger.report_text(resp)
def analyze_text(self, text: str) -> dict:
"""Analyze text using textstat and spacy.
Parameters:
text (str): The text to analyze.
Returns:
(dict): A dictionary containing the complexity metrics.
"""
resp = {}
textstat = import_textstat()
spacy = import_spacy()
if self.complexity_metrics:
text_complexity_metrics = {
"flesch_reading_ease": textstat.flesch_reading_ease(text),
"flesch_kincaid_grade": textstat.flesch_kincaid_grade(text),
"smog_index": textstat.smog_index(text),
"coleman_liau_index": textstat.coleman_liau_index(text),
"automated_readability_index": textstat.automated_readability_index(
text
),
"dale_chall_readability_score": textstat.dale_chall_readability_score(
text
),
"difficult_words": textstat.difficult_words(text),
"linsear_write_formula": textstat.linsear_write_formula(text),
"gunning_fog": textstat.gunning_fog(text),
"text_standard": textstat.text_standard(text),
"fernandez_huerta": textstat.fernandez_huerta(text),
"szigriszt_pazos": textstat.szigriszt_pazos(text),
"gutierrez_polini": textstat.gutierrez_polini(text),
"crawford": textstat.crawford(text),
"gulpease_index": textstat.gulpease_index(text),
"osman": textstat.osman(text),
}
resp.update(text_complexity_metrics)
if self.visualize and self.nlp and self.temp_dir.name is not None:
doc = self.nlp(text)
dep_out = spacy.displacy.render(doc, style="dep", jupyter=False, page=True)
dep_output_path = Path(
self.temp_dir.name, hash_string(f"dep-{text}") + ".html"
)
dep_output_path.open("w", encoding="utf-8").write(dep_out)
ent_out = spacy.displacy.render(doc, style="ent", jupyter=False, page=True)
ent_output_path = Path(
self.temp_dir.name, hash_string(f"ent-{text}") + ".html"
)
ent_output_path.open("w", encoding="utf-8").write(ent_out)
self.logger.report_media(
"Dependencies Plot", text, local_path=dep_output_path
)
self.logger.report_media("Entities Plot", text, local_path=ent_output_path)
return resp
@staticmethod
def _build_llm_df(
base_df: pd.DataFrame, base_df_fields: Sequence, rename_map: Mapping
) -> pd.DataFrame:
base_df_fields = [field for field in base_df_fields if field in base_df]
rename_map = {
map_entry_k: map_entry_v
for map_entry_k, map_entry_v in rename_map.items()
if map_entry_k in base_df_fields
}
llm_df = base_df[base_df_fields].dropna(axis=1)
if rename_map:
llm_df = llm_df.rename(rename_map, axis=1)
return llm_df
def _create_session_analysis_df(self) -> Any:
"""Create a dataframe with all the information from the session."""
pd = import_pandas()
on_llm_end_records_df = pd.DataFrame(self.on_llm_end_records)
llm_input_prompts_df = ClearMLCallbackHandler._build_llm_df(
base_df=on_llm_end_records_df,
base_df_fields=["step", "prompts"]
+ (["name"] if "name" in on_llm_end_records_df else ["id"]),
rename_map={"step": "prompt_step"},
)
complexity_metrics_columns = []
visualizations_columns: List = []
if self.complexity_metrics:
complexity_metrics_columns = [
"flesch_reading_ease",
"flesch_kincaid_grade",
"smog_index",
"coleman_liau_index",
"automated_readability_index",
"dale_chall_readability_score",
"difficult_words",
"linsear_write_formula",
"gunning_fog",
"text_standard",
"fernandez_huerta",
"szigriszt_pazos",
"gutierrez_polini",
"crawford",
"gulpease_index",
"osman",
]
llm_outputs_df = ClearMLCallbackHandler._build_llm_df(
on_llm_end_records_df,
[
"step",
"text",
"token_usage_total_tokens",
"token_usage_prompt_tokens",
"token_usage_completion_tokens",
]
+ complexity_metrics_columns
+ visualizations_columns,
{"step": "output_step", "text": "output"},
)
session_analysis_df = pd.concat([llm_input_prompts_df, llm_outputs_df], axis=1)
return session_analysis_df
def flush_tracker(
self,
name: Optional[str] = None,
langchain_asset: Any = None,
finish: bool = False,
) -> None:
"""Flush the tracker and setup the session.
Everything after this will be a new table.
Args:
name: Name of the performed session so far so it is identifiable
langchain_asset: The langchain asset to save.
finish: Whether to finish the run.
Returns:
None
"""
pd = import_pandas()
clearml = import_clearml()
# Log the action records
self.logger.report_table(
"Action Records", name, table_plot=pd.DataFrame(self.action_records)
)
# Session analysis
session_analysis_df = self._create_session_analysis_df()
self.logger.report_table(
"Session Analysis", name, table_plot=session_analysis_df
)
if self.stream_logs:
self.logger.report_text(
{
"action_records": pd.DataFrame(self.action_records),
"session_analysis": session_analysis_df,
}
)
if langchain_asset:
langchain_asset_path = Path(self.temp_dir.name, "model.json")
try:
langchain_asset.save(langchain_asset_path)
# Create output model and connect it to the task
output_model = clearml.OutputModel(
task=self.task, config_text=load_json(langchain_asset_path)
)
output_model.update_weights(
weights_filename=str(langchain_asset_path),
auto_delete_file=False,
target_filename=name,
)
except ValueError:
langchain_asset.save_agent(langchain_asset_path)
output_model = clearml.OutputModel(
task=self.task, config_text=load_json(langchain_asset_path)
)
output_model.update_weights(
weights_filename=str(langchain_asset_path),
auto_delete_file=False,
target_filename=name,
)
            except NotImplementedError as e:
                print("Could not save model.")  # noqa: T201
                print(repr(e))  # noqa: T201
# Cleanup after adding everything to ClearML
self.task.flush(wait_for_uploads=True)
self.temp_dir.cleanup()
self.temp_dir = tempfile.TemporaryDirectory()
self.reset_callback_meta()
if finish:
self.task.close()
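# Illustrative usage sketch (an editorial addition, not part of the original
# module): `llm` is a placeholder; flush_tracker logs the accumulated records
# as ClearML tables and resets the session for the next table.
def _example_clearml_usage(llm: Any) -> None:
    """Attach the ClearML handler, run the LLM, then flush the session."""
    clearml_callback = ClearMLCallbackHandler(
        task_type="inference",
        project_name="langchain_callback_demo",
        task_name="llm",
        complexity_metrics=True,
        stream_logs=False,
    )
    llm.generate(["Tell me a joke"], callbacks=[clearml_callback])
    clearml_callback.flush_tracker(langchain_asset=llm, name="simple_sequential")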

Some files were not shown because too many files have changed in this diff.