mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-08 18:19:21 +00:00
Compare commits
7 Commits
v0.0.158
...
dev2049/pg
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
64d1e2042e | ||
|
|
0a7245f721 | ||
|
|
6032a051e9 | ||
|
|
fea639c1fc | ||
|
|
2f087d63af | ||
|
|
cc068f1b77 | ||
|
|
ac0a9d02bd |
42
.devcontainer/Dockerfile
Normal file
42
.devcontainer/Dockerfile
Normal file
@@ -0,0 +1,42 @@
|
||||
# This is a Dockerfile for Developer Container
|
||||
|
||||
# Use the Python base image
|
||||
ARG VARIANT="3.11-bullseye"
|
||||
FROM mcr.microsoft.com/vscode/devcontainers/python:0-${VARIANT} AS langchain-dev-base
|
||||
|
||||
USER vscode
|
||||
|
||||
# Define the version of Poetry to install (default is 1.4.2)
|
||||
# Define the directory of python virtual environment
|
||||
ARG PYTHON_VIRTUALENV_HOME=/home/vscode/langchain-py-env \
|
||||
POETRY_VERSION=1.4.2
|
||||
|
||||
ENV POETRY_VIRTUALENVS_IN_PROJECT=false \
|
||||
POETRY_NO_INTERACTION=true
|
||||
|
||||
# Create a Python virtual environment for Poetry and install it
|
||||
RUN python3 -m venv ${PYTHON_VIRTUALENV_HOME} && \
|
||||
$PYTHON_VIRTUALENV_HOME/bin/pip install --upgrade pip && \
|
||||
$PYTHON_VIRTUALENV_HOME/bin/pip install poetry==${POETRY_VERSION}
|
||||
|
||||
ENV PATH="$PYTHON_VIRTUALENV_HOME/bin:$PATH" \
|
||||
VIRTUAL_ENV=$PYTHON_VIRTUALENV_HOME
|
||||
|
||||
# Setup for bash
|
||||
RUN poetry completions bash >> /home/vscode/.bash_completion && \
|
||||
echo "export PATH=$PYTHON_VIRTUALENV_HOME/bin:$PATH" >> ~/.bashrc
|
||||
|
||||
# Set the working directory for the app
|
||||
WORKDIR /workspaces/langchain
|
||||
|
||||
# Use a multi-stage build to install dependencies
|
||||
FROM langchain-dev-base AS langchain-dev-dependencies
|
||||
|
||||
ARG PYTHON_VIRTUALENV_HOME
|
||||
|
||||
# Copy only the dependency files for installation
|
||||
COPY pyproject.toml poetry.lock poetry.toml ./
|
||||
|
||||
# Install the Poetry dependencies (this layer will be cached as long as the dependencies don't change)
|
||||
RUN poetry install --no-interaction --no-ansi --with dev,test,docs
|
||||
|
||||
33
.devcontainer/devcontainer.json
Normal file
33
.devcontainer/devcontainer.json
Normal file
@@ -0,0 +1,33 @@
|
||||
// For format details, see https://aka.ms/devcontainer.json. For config options, see the
|
||||
// README at: https://github.com/devcontainers/templates/tree/main/src/docker-existing-dockerfile
|
||||
{
|
||||
"dockerComposeFile": "./docker-compose.yaml",
|
||||
"service": "langchain",
|
||||
"workspaceFolder": "/workspaces/langchain",
|
||||
"name": "langchain",
|
||||
"customizations": {
|
||||
"vscode": {
|
||||
"extensions": [
|
||||
"ms-python.python"
|
||||
],
|
||||
"settings": {
|
||||
"python.defaultInterpreterPath": "/home/vscode/langchain-py-env/bin/python3.11"
|
||||
}
|
||||
}
|
||||
|
||||
},
|
||||
|
||||
// Features to add to the dev container. More info: https://containers.dev/features.
|
||||
"features": {},
|
||||
|
||||
// Use 'forwardPorts' to make a list of ports inside the container available locally.
|
||||
// "forwardPorts": [],
|
||||
|
||||
// Uncomment the next line to run commands after the container is created.
|
||||
// "postCreateCommand": "cat /etc/os-release",
|
||||
|
||||
// Uncomment to connect as an existing user other than the container default. More info: https://aka.ms/dev-containers-non-root.
|
||||
// "remoteUser": "devcontainer"
|
||||
"remoteUser": "vscode",
|
||||
"overrideCommand": true
|
||||
}
|
||||
31
.devcontainer/docker-compose.yaml
Normal file
31
.devcontainer/docker-compose.yaml
Normal file
@@ -0,0 +1,31 @@
|
||||
version: '3'
|
||||
services:
|
||||
langchain:
|
||||
build:
|
||||
dockerfile: .devcontainer/Dockerfile
|
||||
context: ../
|
||||
volumes:
|
||||
- ../:/workspaces/langchain
|
||||
networks:
|
||||
- langchain-network
|
||||
# environment:
|
||||
# MONGO_ROOT_USERNAME: root
|
||||
# MONGO_ROOT_PASSWORD: example123
|
||||
# depends_on:
|
||||
# - mongo
|
||||
# mongo:
|
||||
# image: mongo
|
||||
# restart: unless-stopped
|
||||
# environment:
|
||||
# MONGO_INITDB_ROOT_USERNAME: root
|
||||
# MONGO_INITDB_ROOT_PASSWORD: example123
|
||||
# ports:
|
||||
# - "27017:27017"
|
||||
# networks:
|
||||
# - langchain-network
|
||||
|
||||
networks:
|
||||
langchain-network:
|
||||
driver: bridge
|
||||
|
||||
|
||||
106
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
Normal file
106
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
Normal file
@@ -0,0 +1,106 @@
|
||||
name: "\U0001F41B Bug Report"
|
||||
description: Submit a bug report to help us improve LangChain
|
||||
labels: ["02 Bug Report"]
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: >
|
||||
Thank you for taking the time to file a bug report. Before creating a new
|
||||
issue, please make sure to take a few moments to check the issue tracker
|
||||
for existing issues about the bug.
|
||||
|
||||
- type: textarea
|
||||
id: system-info
|
||||
attributes:
|
||||
label: System Info
|
||||
description: Please share your system info with us.
|
||||
placeholder: LangChain version, platform, python version, ...
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: who-can-help
|
||||
attributes:
|
||||
label: Who can help?
|
||||
description: |
|
||||
Your issue will be replied to more quickly if you can figure out the right person to tag with @
|
||||
If you know how to use git blame, that is the easiest way, otherwise, here is a rough guide of **who to tag**.
|
||||
|
||||
The core maintainers strive to read all issues, but tagging them will help them prioritize.
|
||||
|
||||
Please tag fewer than 3 people.
|
||||
|
||||
@hwchase17 - project lead
|
||||
|
||||
Tracing / Callbacks
|
||||
- @agola11
|
||||
|
||||
Async
|
||||
- @agola11
|
||||
|
||||
DataLoader Abstractions
|
||||
- @eyurtsev
|
||||
|
||||
LLM/Chat Wrappers
|
||||
- @hwchase17
|
||||
- @agola11
|
||||
|
||||
Tools / Toolkits
|
||||
- @vowelparrot
|
||||
|
||||
placeholder: "@Username ..."
|
||||
|
||||
- type: checkboxes
|
||||
id: information-scripts-examples
|
||||
attributes:
|
||||
label: Information
|
||||
description: "The problem arises when using:"
|
||||
options:
|
||||
- label: "The official example notebooks/scripts"
|
||||
- label: "My own modified scripts"
|
||||
|
||||
- type: checkboxes
|
||||
id: related-components
|
||||
attributes:
|
||||
label: Related Components
|
||||
description: "Select the components related to the issue (if applicable):"
|
||||
options:
|
||||
- label: "LLMs/Chat Models"
|
||||
- label: "Embedding Models"
|
||||
- label: "Prompts / Prompt Templates / Prompt Selectors"
|
||||
- label: "Output Parsers"
|
||||
- label: "Document Loaders"
|
||||
- label: "Vector Stores / Retrievers"
|
||||
- label: "Memory"
|
||||
- label: "Agents / Agent Executors"
|
||||
- label: "Tools / Toolkits"
|
||||
- label: "Chains"
|
||||
- label: "Callbacks/Tracing"
|
||||
- label: "Async"
|
||||
|
||||
- type: textarea
|
||||
id: reproduction
|
||||
validations:
|
||||
required: true
|
||||
attributes:
|
||||
label: Reproduction
|
||||
description: |
|
||||
Please provide a [code sample](https://stackoverflow.com/help/minimal-reproducible-example) that reproduces the problem you ran into. It can be a Colab link or just a code snippet.
|
||||
If you have code snippets, error messages, stack traces please provide them here as well.
|
||||
Important! Use code tags to correctly format your code. See https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting
|
||||
Avoid screenshots when possible, as they are hard to read and (more importantly) don't allow others to copy-and-paste your code.
|
||||
|
||||
placeholder: |
|
||||
Steps to reproduce the behavior:
|
||||
|
||||
1.
|
||||
2.
|
||||
3.
|
||||
|
||||
- type: textarea
|
||||
id: expected-behavior
|
||||
validations:
|
||||
required: true
|
||||
attributes:
|
||||
label: Expected behavior
|
||||
description: "A clear and concise description of what you would expect to happen."
|
||||
6
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
6
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
@@ -0,0 +1,6 @@
|
||||
blank_issues_enabled: true
|
||||
version: 2.1
|
||||
contact_links:
|
||||
- name: Discord
|
||||
url: https://discord.gg/6adMQxSpJS
|
||||
about: General community discussions
|
||||
19
.github/ISSUE_TEMPLATE/documentation.yml
vendored
Normal file
19
.github/ISSUE_TEMPLATE/documentation.yml
vendored
Normal file
@@ -0,0 +1,19 @@
|
||||
name: Documentation
|
||||
description: Report an issue related to the LangChain documentation.
|
||||
title: "DOC: <Please write a comprehensive title after the 'DOC: ' prefix>"
|
||||
labels: [03 - Documentation]
|
||||
|
||||
body:
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: "Issue with current documentation:"
|
||||
description: >
|
||||
Please make sure to leave a reference to the document/code you're
|
||||
referring to.
|
||||
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: "Idea or request for content:"
|
||||
description: >
|
||||
Please describe as clearly as possible what topics you think are missing
|
||||
from the current documentation.
|
||||
30
.github/ISSUE_TEMPLATE/feature-request.yml
vendored
Normal file
30
.github/ISSUE_TEMPLATE/feature-request.yml
vendored
Normal file
@@ -0,0 +1,30 @@
|
||||
name: "\U0001F680 Feature request"
|
||||
description: Submit a proposal/request for a new LangChain feature
|
||||
labels: ["02 Feature Request"]
|
||||
body:
|
||||
- type: textarea
|
||||
id: feature-request
|
||||
validations:
|
||||
required: true
|
||||
attributes:
|
||||
label: Feature request
|
||||
description: |
|
||||
A clear and concise description of the feature proposal. Please provide links to any relevant GitHub repos, papers, or other resources if relevant.
|
||||
|
||||
- type: textarea
|
||||
id: motivation
|
||||
validations:
|
||||
required: true
|
||||
attributes:
|
||||
label: Motivation
|
||||
description: |
|
||||
Please outline the motivation for the proposal. Is your feature request related to a problem? e.g., I'm always frustrated when [...]. If this is related to another GitHub issue, please link here too.
|
||||
|
||||
- type: textarea
|
||||
id: contribution
|
||||
validations:
|
||||
required: true
|
||||
attributes:
|
||||
label: Your contribution
|
||||
description: |
|
||||
Is there any way that you could help, e.g. by submitting a PR? Make sure to read the CONTRIBUTING.MD [readme](https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md)
|
||||
18
.github/ISSUE_TEMPLATE/other.yml
vendored
Normal file
18
.github/ISSUE_TEMPLATE/other.yml
vendored
Normal file
@@ -0,0 +1,18 @@
|
||||
name: Other Issue
|
||||
description: Raise an issue that wouldn't be covered by the other templates.
|
||||
title: "Issue: <Please write a comprehensive title after the 'Issue: ' prefix>"
|
||||
labels: [04 - Other]
|
||||
|
||||
body:
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: "Issue you'd like to raise."
|
||||
description: >
|
||||
Please describe the issue you'd like to raise as clearly as possible.
|
||||
Make sure to include any relevant links or references.
|
||||
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: "Suggestion:"
|
||||
description: >
|
||||
Please outline a suggestion to improve the issue here.
|
||||
@@ -2,7 +2,8 @@
|
||||
|
||||
⚡ Building applications with LLMs through composability ⚡
|
||||
|
||||
[](https://github.com/hwchase17/langchain/actions/workflows/lint.yml) [](https://github.com/hwchase17/langchain/actions/workflows/test.yml) [](https://github.com/hwchase17/langchain/actions/workflows/linkcheck.yml) [](https://pepy.tech/project/langchain) [](https://opensource.org/licenses/MIT) [](https://twitter.com/langchainai) [](https://discord.gg/6adMQxSpJS)
|
||||
[](https://github.com/hwchase17/langchain/actions/workflows/lint.yml) [](https://github.com/hwchase17/langchain/actions/workflows/test.yml) [](https://github.com/hwchase17/langchain/actions/workflows/linkcheck.yml) [](https://pepy.tech/project/langchain) [](https://opensource.org/licenses/MIT) [](https://twitter.com/langchainai) [](https://discord.gg/6adMQxSpJS) [](https://vscode.dev/redirect?url=vscode://ms-vscode-remote.remote-containers/cloneInVolume?url=https://github.com/hwchase17/langchain) [](https://codespaces.new/hwchase17/langchain)
|
||||
|
||||
|
||||
Looking for the JS/TS version? Check out [LangChain.js](https://github.com/hwchase17/langchainjs).
|
||||
|
||||
|
||||
@@ -24,6 +24,10 @@ To import this vectorstore:
|
||||
from langchain.vectorstores.pgvector import PGVector
|
||||
```
|
||||
|
||||
PGVector embedding size is not autodetected. If you are using ChatGPT or any other embedding with 1536 dimensions
|
||||
default is fine. If you are going to use for example HuggingFaceEmbeddings you need to set the environment variable named `PGVECTOR_VECTOR_SIZE`
|
||||
to the needed value, In case of HuggingFaceEmbeddings is would be: `PGVECTOR_VECTOR_SIZE=768`
|
||||
|
||||
### Usage
|
||||
|
||||
For a more detailed walkthrough of the PGVector Wrapper, see [this notebook](../modules/indexes/vectorstores/examples/pgvector.ipynb)
|
||||
|
||||
@@ -6,7 +6,7 @@ import os
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from contextvars import ContextVar
|
||||
from typing import Any, Dict, Generator, List, Optional, Type, TypeVar, Union
|
||||
from typing import Any, Dict, Generator, List, Optional, Type, TypeVar, Union, cast
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from langchain.callbacks.base import (
|
||||
@@ -21,6 +21,7 @@ from langchain.callbacks.openai_info import OpenAICallbackHandler
|
||||
from langchain.callbacks.stdout import StdOutCallbackHandler
|
||||
from langchain.callbacks.tracers.base import TracerSession
|
||||
from langchain.callbacks.tracers.langchain import LangChainTracer, LangChainTracerV2
|
||||
from langchain.callbacks.tracers.schemas import TracerSessionV2
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
|
||||
Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]]
|
||||
@@ -28,7 +29,7 @@ Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]]
|
||||
openai_callback_var: ContextVar[Optional[OpenAICallbackHandler]] = ContextVar(
|
||||
"openai_callback", default=None
|
||||
)
|
||||
tracing_callback_var: ContextVar[Optional[LangChainTracer]] = ContextVar(
|
||||
tracing_callback_var: ContextVar[Optional[LangChainTracer]] = ContextVar( # noqa: E501
|
||||
"tracing_callback", default=None
|
||||
)
|
||||
|
||||
@@ -48,7 +49,7 @@ def tracing_enabled(
|
||||
) -> Generator[TracerSession, None, None]:
|
||||
"""Get Tracer in a context manager."""
|
||||
cb = LangChainTracer()
|
||||
session = cb.load_session(session_name)
|
||||
session = cast(TracerSession, cb.load_session(session_name))
|
||||
tracing_callback_var.set(cb)
|
||||
yield session
|
||||
tracing_callback_var.set(None)
|
||||
@@ -57,7 +58,7 @@ def tracing_enabled(
|
||||
@contextmanager
|
||||
def tracing_v2_enabled(
|
||||
session_name: str = "default",
|
||||
) -> Generator[TracerSession, None, None]:
|
||||
) -> Generator[TracerSessionV2, None, None]:
|
||||
"""Get the experimental tracer handler in a context manager."""
|
||||
# Issue a warning that this is experimental
|
||||
warnings.warn(
|
||||
|
||||
@@ -12,7 +12,9 @@ from langchain.callbacks.tracers.schemas import (
|
||||
LLMRun,
|
||||
ToolRun,
|
||||
TracerSession,
|
||||
TracerSessionBase,
|
||||
TracerSessionCreate,
|
||||
TracerSessionV2,
|
||||
)
|
||||
from langchain.schema import LLMResult
|
||||
|
||||
@@ -27,7 +29,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.run_map: Dict[str, Union[LLMRun, ChainRun, ToolRun]] = {}
|
||||
self.session: Optional[TracerSession] = None
|
||||
self.session: Optional[Union[TracerSessionV2, TracerSession]] = None
|
||||
|
||||
@staticmethod
|
||||
def _add_child_run(
|
||||
@@ -49,22 +51,31 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
"""Persist a run."""
|
||||
|
||||
@abstractmethod
|
||||
def _persist_session(self, session: TracerSessionCreate) -> TracerSession:
|
||||
def _persist_session(
|
||||
self, session: TracerSessionBase
|
||||
) -> Union[TracerSession, TracerSessionV2]:
|
||||
"""Persist a tracing session."""
|
||||
|
||||
def new_session(self, name: Optional[str] = None, **kwargs: Any) -> TracerSession:
|
||||
def _get_session_create(
|
||||
self, name: Optional[str] = None, **kwargs: Any
|
||||
) -> TracerSessionBase:
|
||||
return TracerSessionCreate(name=name, extra=kwargs)
|
||||
|
||||
def new_session(
|
||||
self, name: Optional[str] = None, **kwargs: Any
|
||||
) -> Union[TracerSession, TracerSessionV2]:
|
||||
"""NOT thread safe, do not call this method from multiple threads."""
|
||||
session_create = TracerSessionCreate(name=name, extra=kwargs)
|
||||
session_create = self._get_session_create(name=name, **kwargs)
|
||||
session = self._persist_session(session_create)
|
||||
self.session = session
|
||||
return session
|
||||
|
||||
@abstractmethod
|
||||
def load_session(self, session_name: str) -> TracerSession:
|
||||
def load_session(self, session_name: str) -> Union[TracerSession, TracerSessionV2]:
|
||||
"""Load a tracing session and set it as the Tracer's session."""
|
||||
|
||||
@abstractmethod
|
||||
def load_default_session(self) -> TracerSession:
|
||||
def load_default_session(self) -> Union[TracerSession, TracerSessionV2]:
|
||||
"""Load the default tracing session and set it as the Tracer's session."""
|
||||
|
||||
def _start_trace(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
|
||||
|
||||
@@ -14,21 +14,32 @@ from langchain.callbacks.tracers.schemas import (
|
||||
Run,
|
||||
ToolRun,
|
||||
TracerSession,
|
||||
TracerSessionCreate,
|
||||
TracerSessionBase,
|
||||
TracerSessionV2,
|
||||
TracerSessionV2Create,
|
||||
)
|
||||
|
||||
|
||||
def _get_headers() -> Dict[str, Any]:
|
||||
"""Get the headers for the LangChain API."""
|
||||
headers: Dict[str, Any] = {"Content-Type": "application/json"}
|
||||
if os.getenv("LANGCHAIN_API_KEY"):
|
||||
headers["x-api-key"] = os.getenv("LANGCHAIN_API_KEY")
|
||||
return headers
|
||||
|
||||
|
||||
def _get_endpoint() -> str:
|
||||
return os.getenv("LANGCHAIN_ENDPOINT", "http://localhost:8000")
|
||||
|
||||
|
||||
class LangChainTracer(BaseTracer):
|
||||
"""An implementation of the SharedTracer that POSTS to the langchain endpoint."""
|
||||
|
||||
def __init__(self, session_name: str = "default", **kwargs: Any) -> None:
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
"""Initialize the LangChain tracer."""
|
||||
super().__init__(**kwargs)
|
||||
self._endpoint: str = os.getenv("LANGCHAIN_ENDPOINT", "http://localhost:8000")
|
||||
self._headers: Dict[str, Any] = {"Content-Type": "application/json"}
|
||||
if os.getenv("LANGCHAIN_API_KEY"):
|
||||
self._headers["x-api-key"] = os.getenv("LANGCHAIN_API_KEY")
|
||||
self.session = self.load_session(session_name)
|
||||
self._endpoint = _get_endpoint()
|
||||
self._headers = _get_headers()
|
||||
|
||||
def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
|
||||
"""Persist a run."""
|
||||
@@ -48,7 +59,9 @@ class LangChainTracer(BaseTracer):
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to persist run: {e}")
|
||||
|
||||
def _persist_session(self, session_create: TracerSessionCreate) -> TracerSession:
|
||||
def _persist_session(
|
||||
self, session_create: TracerSessionBase
|
||||
) -> Union[TracerSession, TracerSessionV2]:
|
||||
"""Persist a session."""
|
||||
try:
|
||||
r = requests.post(
|
||||
@@ -81,22 +94,89 @@ class LangChainTracer(BaseTracer):
|
||||
self.session = tracer_session
|
||||
return tracer_session
|
||||
|
||||
def load_session(self, session_name: str) -> TracerSession:
|
||||
def load_session(self, session_name: str) -> Union[TracerSession, TracerSessionV2]:
|
||||
"""Load a session with the given name from the tracer."""
|
||||
return self._load_session(session_name)
|
||||
|
||||
def load_default_session(self) -> TracerSession:
|
||||
def load_default_session(self) -> Union[TracerSession, TracerSessionV2]:
|
||||
"""Load the default tracing session and set it as the Tracer's session."""
|
||||
return self._load_session("default")
|
||||
|
||||
|
||||
def _get_tenant_id() -> Optional[str]:
|
||||
"""Get the tenant ID for the LangChain API."""
|
||||
tenant_id: Optional[str] = os.getenv("LANGCHAIN_TENANT_ID")
|
||||
if tenant_id:
|
||||
return tenant_id
|
||||
endpoint = _get_endpoint()
|
||||
headers = _get_headers()
|
||||
response = requests.get(endpoint + "/tenants", headers=headers)
|
||||
response.raise_for_status()
|
||||
tenants: List[Dict[str, Any]] = response.json()
|
||||
if not tenants:
|
||||
raise ValueError(f"No tenants found for URL {endpoint}")
|
||||
return tenants[0]["id"]
|
||||
|
||||
|
||||
class LangChainTracerV2(LangChainTracer):
|
||||
"""An implementation of the SharedTracer that POSTS to the langchain endpoint."""
|
||||
|
||||
@staticmethod
|
||||
def _convert_run(run: Union[LLMRun, ChainRun, ToolRun]) -> Run:
|
||||
"""Convert a run to a Run."""
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
"""Initialize the LangChain tracer."""
|
||||
super().__init__(**kwargs)
|
||||
self._endpoint = _get_endpoint()
|
||||
self._headers = _get_headers()
|
||||
self.tenant_id = _get_tenant_id()
|
||||
|
||||
def _get_session_create(
|
||||
self, name: Optional[str] = None, **kwargs: Any
|
||||
) -> TracerSessionBase:
|
||||
return TracerSessionV2Create(name=name, extra=kwargs, tenant_id=self.tenant_id)
|
||||
|
||||
def _persist_session(self, session_create: TracerSessionBase) -> TracerSessionV2:
|
||||
"""Persist a session."""
|
||||
try:
|
||||
r = requests.post(
|
||||
f"{self._endpoint}/sessions",
|
||||
data=session_create.json(),
|
||||
headers=self._headers,
|
||||
)
|
||||
session = TracerSessionV2(id=r.json()["id"], **session_create.dict())
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to create session, using default session: {e}")
|
||||
session = self.load_session("default")
|
||||
return session
|
||||
|
||||
def _get_default_query_params(self) -> Dict[str, Any]:
|
||||
"""Get the query params for the LangChain API."""
|
||||
return {"tenant_id": self.tenant_id}
|
||||
|
||||
def load_session(self, session_name: str) -> TracerSessionV2:
|
||||
"""Load a session with the given name from the tracer."""
|
||||
try:
|
||||
url = f"{self._endpoint}/sessions"
|
||||
params = {"tenant_id": self.tenant_id}
|
||||
if session_name:
|
||||
params["name"] = session_name
|
||||
r = requests.get(url, headers=self._headers, params=params)
|
||||
tracer_session = TracerSessionV2(**r.json()[0])
|
||||
except Exception as e:
|
||||
session_type = "default" if not session_name else session_name
|
||||
logging.warning(
|
||||
f"Failed to load {session_type} session, using empty session: {e}"
|
||||
)
|
||||
tracer_session = TracerSessionV2(id=1, tenant_id=self.tenant_id)
|
||||
|
||||
self.session = tracer_session
|
||||
return tracer_session
|
||||
|
||||
def load_default_session(self) -> TracerSessionV2:
|
||||
"""Load the default tracing session and set it as the Tracer's session."""
|
||||
return self.load_session("default")
|
||||
|
||||
def _convert_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> Run:
|
||||
"""Convert a run to a Run."""
|
||||
session = self.session or self.load_default_session()
|
||||
inputs: Dict[str, Any] = {}
|
||||
outputs: Optional[Dict[str, Any]] = None
|
||||
child_runs: List[Union[LLMRun, ChainRun, ToolRun]] = []
|
||||
@@ -126,30 +206,30 @@ class LangChainTracerV2(LangChainTracer):
|
||||
|
||||
return Run(
|
||||
id=run.uuid,
|
||||
name=run.serialized.get("name"),
|
||||
name=run.serialized.get("name", f"{run_type}-{run.uuid}"),
|
||||
start_time=run.start_time,
|
||||
end_time=run.end_time,
|
||||
extra=run.extra,
|
||||
extra=run.extra or {},
|
||||
error=run.error,
|
||||
execution_order=run.execution_order,
|
||||
serialized=run.serialized,
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
session_id=run.session_id,
|
||||
session_id=session.id,
|
||||
run_type=run_type,
|
||||
parent_run_id=run.parent_uuid,
|
||||
child_runs=[LangChainTracerV2._convert_run(child) for child in child_runs],
|
||||
child_runs=[self._convert_run(child) for child in child_runs],
|
||||
)
|
||||
|
||||
def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
|
||||
"""Persist a run."""
|
||||
run_create = self._convert_run(run)
|
||||
|
||||
try:
|
||||
requests.post(
|
||||
result = requests.post(
|
||||
f"{self._endpoint}/runs",
|
||||
data=run_create.json(),
|
||||
headers=self._headers,
|
||||
)
|
||||
result.raise_for_status()
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to persist run: {e}")
|
||||
|
||||
@@ -31,6 +31,24 @@ class TracerSession(TracerSessionBase):
|
||||
id: int
|
||||
|
||||
|
||||
class TracerSessionV2Base(TracerSessionBase):
|
||||
"""A creation class for TracerSessionV2."""
|
||||
|
||||
tenant_id: UUID
|
||||
|
||||
|
||||
class TracerSessionV2Create(TracerSessionBase):
|
||||
"""A creation class for TracerSessionV2."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TracerSessionV2(TracerSessionV2Base):
|
||||
"""TracerSession schema for the V2 API."""
|
||||
|
||||
id: UUID
|
||||
|
||||
|
||||
class BaseRun(BaseModel):
|
||||
"""Base class for Run."""
|
||||
|
||||
@@ -93,9 +111,9 @@ class Run(BaseModel):
|
||||
serialized: dict
|
||||
inputs: dict
|
||||
outputs: Optional[dict]
|
||||
session_id: int
|
||||
session_id: UUID
|
||||
parent_run_id: Optional[UUID]
|
||||
example_id: Optional[UUID]
|
||||
reference_example_id: Optional[UUID]
|
||||
run_type: RunTypeEnum
|
||||
child_runs: List[Run] = Field(default_factory=list)
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ from typing import Any, Iterable, List, Optional
|
||||
|
||||
import sqlalchemy
|
||||
from sqlalchemy import (
|
||||
CursorResult,
|
||||
MetaData,
|
||||
Table,
|
||||
create_engine,
|
||||
@@ -14,7 +13,7 @@ from sqlalchemy import (
|
||||
select,
|
||||
text,
|
||||
)
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.engine import CursorResult, Engine
|
||||
from sqlalchemy.exc import ProgrammingError, SQLAlchemyError
|
||||
from sqlalchemy.schema import CreateTable
|
||||
|
||||
|
||||
@@ -146,7 +146,8 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
|
||||
@property
|
||||
def is_single_input(self) -> bool:
|
||||
"""Whether the tool only accepts a single input."""
|
||||
return len(self.args) == 1
|
||||
keys = {k for k in self.args if k != "kwargs"}
|
||||
return len(keys) == 1
|
||||
|
||||
@property
|
||||
def args(self) -> dict:
|
||||
|
||||
@@ -36,7 +36,6 @@ class PythonREPLTool(BaseTool):
|
||||
self,
|
||||
query: str,
|
||||
run_manager: Optional[CallbackManagerForToolRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Use the tool."""
|
||||
if self.sanitize_input:
|
||||
@@ -47,7 +46,6 @@ class PythonREPLTool(BaseTool):
|
||||
self,
|
||||
query: str,
|
||||
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Use the tool asynchronously."""
|
||||
raise NotImplementedError("PythonReplTool does not support async")
|
||||
|
||||
@@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
|
||||
import enum
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type
|
||||
|
||||
@@ -13,13 +14,13 @@ from sqlalchemy.orm import Session, declarative_base, relationship
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
from langchain.utils import get_from_dict_or_env, get_from_env
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
|
||||
Base = declarative_base() # type: Any
|
||||
|
||||
|
||||
ADA_TOKEN_COUNT = 1536
|
||||
PGVECTOR_VECTOR_SIZE = 1536
|
||||
_LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain"
|
||||
|
||||
|
||||
@@ -79,13 +80,28 @@ class EmbeddingStore(BaseModel):
|
||||
)
|
||||
collection = relationship(CollectionStore, back_populates="embeddings")
|
||||
|
||||
embedding: Vector = sqlalchemy.Column(Vector(ADA_TOKEN_COUNT))
|
||||
embedding: Vector = sqlalchemy.Column(Vector(PGVECTOR_VECTOR_SIZE))
|
||||
document = sqlalchemy.Column(sqlalchemy.String, nullable=True)
|
||||
cmetadata = sqlalchemy.Column(JSON, nullable=True)
|
||||
|
||||
# custom_id : any user defined id
|
||||
custom_id = sqlalchemy.Column(sqlalchemy.String, nullable=True)
|
||||
|
||||
def __init__(
|
||||
self, *args: Any, vector_size: Optional[int] = None, **kwargs: Any
|
||||
) -> None:
|
||||
if "embedding" not in kwargs:
|
||||
if vector_size is None:
|
||||
vector_size = int(
|
||||
get_from_env(
|
||||
"vector_size",
|
||||
"PGVECTOR_VECTOR_SIZE",
|
||||
default=str(PGVECTOR_VECTOR_SIZE),
|
||||
)
|
||||
)
|
||||
kwargs["embedding"] = sqlalchemy.Column(Vector(vector_size))
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class QueryResult:
|
||||
EmbeddingStore: EmbeddingStore
|
||||
|
||||
@@ -13,7 +13,7 @@ langchain-server = "langchain.server:main"
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.8.1,<4.0"
|
||||
pydantic = "^1"
|
||||
SQLAlchemy = ">=1.3,<3"
|
||||
SQLAlchemy = ">=1.4,<3"
|
||||
requests = "^2"
|
||||
PyYAML = ">=5.4.1"
|
||||
numpy = "^1"
|
||||
|
||||
@@ -1,11 +1,8 @@
|
||||
"""Test tool utils."""
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
from typing import Any, Optional, Type, Union
|
||||
from typing import Any, Type
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.agents.agent import Agent
|
||||
from langchain.agents.chat.base import ChatAgent
|
||||
@@ -15,383 +12,6 @@ from langchain.agents.mrkl.base import ZeroShotAgent
|
||||
from langchain.agents.react.base import ReActDocstoreAgent, ReActTextWorldAgent
|
||||
from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent
|
||||
from langchain.agents.tools import Tool, tool
|
||||
from langchain.tools.base import BaseTool, SchemaAnnotationError, StructuredTool
|
||||
|
||||
|
||||
def test_unnamed_decorator() -> None:
|
||||
"""Test functionality with unnamed decorator."""
|
||||
|
||||
@tool
|
||||
def search_api(query: str) -> str:
|
||||
"""Search the API for the query."""
|
||||
return "API result"
|
||||
|
||||
assert isinstance(search_api, BaseTool)
|
||||
assert search_api.name == "search_api"
|
||||
assert not search_api.return_direct
|
||||
assert search_api("test") == "API result"
|
||||
|
||||
|
||||
class _MockSchema(BaseModel):
|
||||
arg1: int
|
||||
arg2: bool
|
||||
arg3: Optional[dict] = None
|
||||
|
||||
|
||||
class _MockStructuredTool(BaseTool):
|
||||
name = "structured_api"
|
||||
args_schema: Type[BaseModel] = _MockSchema
|
||||
description = "A Structured Tool"
|
||||
|
||||
def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
|
||||
return f"{arg1} {arg2} {arg3}"
|
||||
|
||||
async def _arun(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def test_structured_args() -> None:
|
||||
"""Test functionality with structured arguments."""
|
||||
structured_api = _MockStructuredTool()
|
||||
assert isinstance(structured_api, BaseTool)
|
||||
assert structured_api.name == "structured_api"
|
||||
expected_result = "1 True {'foo': 'bar'}"
|
||||
args = {"arg1": 1, "arg2": True, "arg3": {"foo": "bar"}}
|
||||
assert structured_api.run(args) == expected_result
|
||||
|
||||
|
||||
def test_unannotated_base_tool_raises_error() -> None:
|
||||
"""Test that a BaseTool without type hints raises an exception.""" ""
|
||||
with pytest.raises(SchemaAnnotationError):
|
||||
|
||||
class _UnAnnotatedTool(BaseTool):
|
||||
name = "structured_api"
|
||||
# This would silently be ignored without the custom metaclass
|
||||
args_schema = _MockSchema
|
||||
description = "A Structured Tool"
|
||||
|
||||
def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
|
||||
return f"{arg1} {arg2} {arg3}"
|
||||
|
||||
async def _arun(
|
||||
self, arg1: int, arg2: bool, arg3: Optional[dict] = None
|
||||
) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def test_misannotated_base_tool_raises_error() -> None:
|
||||
"""Test that a BaseTool with the incorrrect typehint raises an exception.""" ""
|
||||
with pytest.raises(SchemaAnnotationError):
|
||||
|
||||
class _MisAnnotatedTool(BaseTool):
|
||||
name = "structured_api"
|
||||
# This would silently be ignored without the custom metaclass
|
||||
args_schema: BaseModel = _MockSchema # type: ignore
|
||||
description = "A Structured Tool"
|
||||
|
||||
def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
|
||||
return f"{arg1} {arg2} {arg3}"
|
||||
|
||||
async def _arun(
|
||||
self, arg1: int, arg2: bool, arg3: Optional[dict] = None
|
||||
) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def test_forward_ref_annotated_base_tool_accepted() -> None:
|
||||
"""Test that a using forward ref annotation syntax is accepted.""" ""
|
||||
|
||||
class _ForwardRefAnnotatedTool(BaseTool):
|
||||
name = "structured_api"
|
||||
args_schema: "Type[BaseModel]" = _MockSchema
|
||||
description = "A Structured Tool"
|
||||
|
||||
def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
|
||||
return f"{arg1} {arg2} {arg3}"
|
||||
|
||||
async def _arun(
|
||||
self, arg1: int, arg2: bool, arg3: Optional[dict] = None
|
||||
) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def test_subclass_annotated_base_tool_accepted() -> None:
|
||||
"""Test BaseTool child w/ custom schema isn't overwritten."""
|
||||
|
||||
class _ForwardRefAnnotatedTool(BaseTool):
|
||||
name = "structured_api"
|
||||
args_schema: Type[_MockSchema] = _MockSchema
|
||||
description = "A Structured Tool"
|
||||
|
||||
def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
|
||||
return f"{arg1} {arg2} {arg3}"
|
||||
|
||||
async def _arun(
|
||||
self, arg1: int, arg2: bool, arg3: Optional[dict] = None
|
||||
) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
assert issubclass(_ForwardRefAnnotatedTool, BaseTool)
|
||||
tool = _ForwardRefAnnotatedTool()
|
||||
assert tool.args_schema == _MockSchema
|
||||
|
||||
|
||||
def test_decorator_with_specified_schema() -> None:
|
||||
"""Test that manually specified schemata are passed through to the tool."""
|
||||
|
||||
@tool(args_schema=_MockSchema)
|
||||
def tool_func(arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
|
||||
"""Return the arguments directly."""
|
||||
return f"{arg1} {arg2} {arg3}"
|
||||
|
||||
assert isinstance(tool_func, BaseTool)
|
||||
assert tool_func.args_schema == _MockSchema
|
||||
|
||||
|
||||
def test_decorated_function_schema_equivalent() -> None:
|
||||
"""Test that a BaseTool without a schema meets expectations."""
|
||||
|
||||
@tool
|
||||
def structured_tool_input(
|
||||
arg1: int, arg2: bool, arg3: Optional[dict] = None
|
||||
) -> str:
|
||||
"""Return the arguments directly."""
|
||||
return f"{arg1} {arg2} {arg3}"
|
||||
|
||||
assert isinstance(structured_tool_input, BaseTool)
|
||||
assert structured_tool_input.args_schema is not None
|
||||
assert (
|
||||
structured_tool_input.args_schema.schema()["properties"]
|
||||
== _MockSchema.schema()["properties"]
|
||||
== structured_tool_input.args
|
||||
)
|
||||
|
||||
|
||||
def test_structured_args_decorator_no_infer_schema() -> None:
|
||||
"""Test functionality with structured arguments parsed as a decorator."""
|
||||
|
||||
@tool(infer_schema=False)
|
||||
def structured_tool_input(
|
||||
arg1: int, arg2: Union[float, datetime], opt_arg: Optional[dict] = None
|
||||
) -> str:
|
||||
"""Return the arguments directly."""
|
||||
return f"{arg1}, {arg2}, {opt_arg}"
|
||||
|
||||
assert isinstance(structured_tool_input, BaseTool)
|
||||
assert structured_tool_input.name == "structured_tool_input"
|
||||
args = {"arg1": 1, "arg2": 0.001, "opt_arg": {"foo": "bar"}}
|
||||
expected_result = "1, 0.001, {'foo': 'bar'}"
|
||||
with pytest.raises(ValueError):
|
||||
assert structured_tool_input.run(args) == expected_result
|
||||
|
||||
|
||||
def test_structured_single_str_decorator_no_infer_schema() -> None:
|
||||
"""Test functionality with structured arguments parsed as a decorator."""
|
||||
|
||||
@tool(infer_schema=False)
|
||||
def unstructured_tool_input(tool_input: str) -> str:
|
||||
"""Return the arguments directly."""
|
||||
return f"{tool_input}"
|
||||
|
||||
assert isinstance(unstructured_tool_input, BaseTool)
|
||||
assert unstructured_tool_input.args_schema is None
|
||||
assert unstructured_tool_input.run("foo") == "foo"
|
||||
|
||||
|
||||
def test_base_tool_inheritance_base_schema() -> None:
|
||||
"""Test schema is correctly inferred when inheriting from BaseTool."""
|
||||
|
||||
class _MockSimpleTool(BaseTool):
|
||||
name = "simple_tool"
|
||||
description = "A Simple Tool"
|
||||
|
||||
def _run(self, tool_input: str) -> str:
|
||||
return f"{tool_input}"
|
||||
|
||||
async def _arun(self, tool_input: str) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
simple_tool = _MockSimpleTool()
|
||||
assert simple_tool.args_schema is None
|
||||
expected_args = {"tool_input": {"title": "Tool Input", "type": "string"}}
|
||||
assert simple_tool.args == expected_args
|
||||
|
||||
|
||||
def test_tool_lambda_args_schema() -> None:
|
||||
"""Test args schema inference when the tool argument is a lambda function."""
|
||||
|
||||
tool = Tool(
|
||||
name="tool",
|
||||
description="A tool",
|
||||
func=lambda tool_input: tool_input,
|
||||
)
|
||||
assert tool.args_schema is None
|
||||
expected_args = {"tool_input": {"type": "string"}}
|
||||
assert tool.args == expected_args
|
||||
|
||||
|
||||
def test_structured_tool_lambda_multi_args_schema() -> None:
|
||||
"""Test args schema inference when the tool argument is a lambda function."""
|
||||
tool = StructuredTool.from_function(
|
||||
name="tool",
|
||||
description="A tool",
|
||||
func=lambda tool_input, other_arg: f"{tool_input}{other_arg}", # type: ignore
|
||||
)
|
||||
assert tool.args_schema is not None
|
||||
expected_args = {
|
||||
"tool_input": {"title": "Tool Input"},
|
||||
"other_arg": {"title": "Other Arg"},
|
||||
}
|
||||
assert tool.args == expected_args
|
||||
|
||||
|
||||
def test_tool_partial_function_args_schema() -> None:
|
||||
"""Test args schema inference when the tool argument is a partial function."""
|
||||
|
||||
def func(tool_input: str, other_arg: str) -> str:
|
||||
return tool_input + other_arg
|
||||
|
||||
tool = Tool(
|
||||
name="tool",
|
||||
description="A tool",
|
||||
func=partial(func, other_arg="foo"),
|
||||
)
|
||||
assert tool.run("bar") == "barfoo"
|
||||
|
||||
|
||||
def test_empty_args_decorator() -> None:
|
||||
"""Test inferred schema of decorated fn with no args."""
|
||||
|
||||
@tool
|
||||
def empty_tool_input() -> str:
|
||||
"""Return a constant."""
|
||||
return "the empty result"
|
||||
|
||||
assert isinstance(empty_tool_input, BaseTool)
|
||||
assert empty_tool_input.name == "empty_tool_input"
|
||||
assert empty_tool_input.args == {}
|
||||
assert empty_tool_input.run({}) == "the empty result"
|
||||
|
||||
|
||||
def test_named_tool_decorator() -> None:
|
||||
"""Test functionality when arguments are provided as input to decorator."""
|
||||
|
||||
@tool("search")
|
||||
def search_api(query: str) -> str:
|
||||
"""Search the API for the query."""
|
||||
return "API result"
|
||||
|
||||
assert isinstance(search_api, BaseTool)
|
||||
assert search_api.name == "search"
|
||||
assert not search_api.return_direct
|
||||
|
||||
|
||||
def test_named_tool_decorator_return_direct() -> None:
|
||||
"""Test functionality when arguments and return direct are provided as input."""
|
||||
|
||||
@tool("search", return_direct=True)
|
||||
def search_api(query: str) -> str:
|
||||
"""Search the API for the query."""
|
||||
return "API result"
|
||||
|
||||
assert isinstance(search_api, BaseTool)
|
||||
assert search_api.name == "search"
|
||||
assert search_api.return_direct
|
||||
|
||||
|
||||
def test_unnamed_tool_decorator_return_direct() -> None:
|
||||
"""Test functionality when only return direct is provided."""
|
||||
|
||||
@tool(return_direct=True)
|
||||
def search_api(query: str) -> str:
|
||||
"""Search the API for the query."""
|
||||
return "API result"
|
||||
|
||||
assert isinstance(search_api, BaseTool)
|
||||
assert search_api.name == "search_api"
|
||||
assert search_api.return_direct
|
||||
|
||||
|
||||
def test_tool_with_kwargs() -> None:
|
||||
"""Test functionality when only return direct is provided."""
|
||||
|
||||
@tool(return_direct=True)
|
||||
def search_api(
|
||||
arg_0: str,
|
||||
arg_1: float = 4.3,
|
||||
ping: str = "hi",
|
||||
) -> str:
|
||||
"""Search the API for the query."""
|
||||
return f"arg_0={arg_0}, arg_1={arg_1}, ping={ping}"
|
||||
|
||||
assert isinstance(search_api, BaseTool)
|
||||
result = search_api.run(
|
||||
tool_input={
|
||||
"arg_0": "foo",
|
||||
"arg_1": 3.2,
|
||||
"ping": "pong",
|
||||
}
|
||||
)
|
||||
assert result == "arg_0=foo, arg_1=3.2, ping=pong"
|
||||
|
||||
result = search_api.run(
|
||||
tool_input={
|
||||
"arg_0": "foo",
|
||||
}
|
||||
)
|
||||
assert result == "arg_0=foo, arg_1=4.3, ping=hi"
|
||||
# For backwards compatibility, we still accept a single str arg
|
||||
result = search_api.run("foobar")
|
||||
assert result == "arg_0=foobar, arg_1=4.3, ping=hi"
|
||||
|
||||
|
||||
def test_missing_docstring() -> None:
|
||||
"""Test error is raised when docstring is missing."""
|
||||
# expect to throw a value error if theres no docstring
|
||||
with pytest.raises(AssertionError, match="Function must have a docstring"):
|
||||
|
||||
@tool
|
||||
def search_api(query: str) -> str:
|
||||
return "API result"
|
||||
|
||||
|
||||
def test_create_tool_positional_args() -> None:
|
||||
"""Test that positional arguments are allowed."""
|
||||
test_tool = Tool("test_name", lambda x: x, "test_description")
|
||||
assert test_tool("foo") == "foo"
|
||||
assert test_tool.name == "test_name"
|
||||
assert test_tool.description == "test_description"
|
||||
assert test_tool.is_single_input
|
||||
|
||||
|
||||
def test_create_tool_keyword_args() -> None:
|
||||
"""Test that keyword arguments are allowed."""
|
||||
test_tool = Tool(name="test_name", func=lambda x: x, description="test_description")
|
||||
assert test_tool.is_single_input
|
||||
assert test_tool("foo") == "foo"
|
||||
assert test_tool.name == "test_name"
|
||||
assert test_tool.description == "test_description"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_async_tool() -> None:
|
||||
"""Test that async tools are allowed."""
|
||||
|
||||
async def _test_func(x: str) -> str:
|
||||
return x
|
||||
|
||||
test_tool = Tool(
|
||||
name="test_name",
|
||||
func=lambda x: x,
|
||||
description="test_description",
|
||||
coroutine=_test_func,
|
||||
)
|
||||
assert test_tool.is_single_input
|
||||
assert test_tool("foo") == "foo"
|
||||
assert test_tool.name == "test_name"
|
||||
assert test_tool.description == "test_description"
|
||||
assert test_tool.coroutine is not None
|
||||
assert await test_tool.arun("foo") == "foo"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
@@ -2,8 +2,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import List, Union
|
||||
from uuid import uuid4
|
||||
from typing import List, Tuple, Union
|
||||
from unittest.mock import Mock, patch
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
from freezegun import freeze_time
|
||||
@@ -16,7 +17,8 @@ from langchain.callbacks.tracers.base import (
|
||||
TracerException,
|
||||
TracerSession,
|
||||
)
|
||||
from langchain.callbacks.tracers.schemas import TracerSessionCreate
|
||||
from langchain.callbacks.tracers.langchain import LangChainTracerV2
|
||||
from langchain.callbacks.tracers.schemas import Run, TracerSessionBase, TracerSessionV2
|
||||
from langchain.schema import LLMResult
|
||||
|
||||
TEST_SESSION_ID = 2023
|
||||
@@ -27,7 +29,7 @@ def load_session(session_name: str) -> TracerSession:
|
||||
return TracerSession(id=1, name=session_name, start_time=datetime.utcnow())
|
||||
|
||||
|
||||
def _persist_session(session: TracerSessionCreate) -> TracerSession:
|
||||
def _persist_session(session: TracerSessionBase) -> TracerSession:
|
||||
"""Persist a tracing session."""
|
||||
return TracerSession(id=TEST_SESSION_ID, **session.dict())
|
||||
|
||||
@@ -49,7 +51,7 @@ class FakeTracer(BaseTracer):
|
||||
"""Persist a run."""
|
||||
self.runs.append(run)
|
||||
|
||||
def _persist_session(self, session: TracerSessionCreate) -> TracerSession:
|
||||
def _persist_session(self, session: TracerSessionBase) -> TracerSession:
|
||||
"""Persist a tracing session."""
|
||||
return _persist_session(session)
|
||||
|
||||
@@ -473,3 +475,125 @@ def test_tracer_nested_runs_on_error() -> None:
|
||||
)
|
||||
|
||||
assert tracer.runs == [compare_run] * 3
|
||||
|
||||
|
||||
_SESSION_ID = UUID("4fbf7c55-2727-4711-8964-d821ed4d4e2a")
|
||||
_TENANT_ID = UUID("57a08cc4-73d2-4236-8378-549099d07fad")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def lang_chain_tracer_v2(monkeypatch: pytest.MonkeyPatch) -> LangChainTracerV2:
|
||||
monkeypatch.setenv("LANGCHAIN_TENANT_ID", "test-tenant-id")
|
||||
monkeypatch.setenv("LANGCHAIN_ENDPOINT", "http://test-endpoint.com")
|
||||
monkeypatch.setenv("LANGCHAIN_API_KEY", "foo")
|
||||
tracer = LangChainTracerV2()
|
||||
return tracer
|
||||
|
||||
|
||||
# Mock a sample TracerSessionV2 object
|
||||
@pytest.fixture
|
||||
def sample_tracer_session_v2() -> TracerSessionV2:
|
||||
return TracerSessionV2(id=_SESSION_ID, name="Sample session", tenant_id=_TENANT_ID)
|
||||
|
||||
|
||||
# Mock a sample LLMRun, ChainRun, and ToolRun objects
|
||||
@pytest.fixture
|
||||
def sample_runs() -> Tuple[LLMRun, ChainRun, ToolRun]:
|
||||
llm_run = LLMRun(
|
||||
uuid="57a08cc4-73d2-4236-8370-549099d07fad",
|
||||
name="llm_run",
|
||||
execution_order=1,
|
||||
child_execution_order=1,
|
||||
session_id=1,
|
||||
prompts=[],
|
||||
response=LLMResult(generations=[[]]),
|
||||
serialized={},
|
||||
extra={},
|
||||
)
|
||||
chain_run = ChainRun(
|
||||
uuid="57a08cc4-73d2-4236-8371-549099d07fad",
|
||||
name="chain_run",
|
||||
execution_order=1,
|
||||
child_execution_order=1,
|
||||
session_id=1,
|
||||
serialized={},
|
||||
inputs={},
|
||||
outputs=None,
|
||||
child_llm_runs=[llm_run],
|
||||
child_chain_runs=[],
|
||||
child_tool_runs=[],
|
||||
extra={},
|
||||
)
|
||||
tool_run = ToolRun(
|
||||
uuid="57a08cc4-73d2-4236-8372-549099d07fad",
|
||||
name="tool_run",
|
||||
execution_order=1,
|
||||
child_execution_order=1,
|
||||
session_id=1,
|
||||
tool_input="test",
|
||||
action="{}",
|
||||
serialized={},
|
||||
child_llm_runs=[],
|
||||
child_chain_runs=[],
|
||||
child_tool_runs=[],
|
||||
extra={},
|
||||
)
|
||||
return llm_run, chain_run, tool_run
|
||||
|
||||
|
||||
# Test _get_default_query_params method
|
||||
def test_get_default_query_params(lang_chain_tracer_v2: LangChainTracerV2) -> None:
|
||||
expected = {"tenant_id": "test-tenant-id"}
|
||||
result = lang_chain_tracer_v2._get_default_query_params()
|
||||
assert result == expected
|
||||
|
||||
|
||||
# Test load_session method
|
||||
@patch("langchain.callbacks.tracers.langchain.requests.get")
|
||||
def test_load_session(
|
||||
mock_requests_get: Mock,
|
||||
lang_chain_tracer_v2: LangChainTracerV2,
|
||||
sample_tracer_session_v2: TracerSessionV2,
|
||||
) -> None:
|
||||
"""Test that load_session method returns a TracerSessionV2 object."""
|
||||
mock_requests_get.return_value.json.return_value = [sample_tracer_session_v2.dict()]
|
||||
result = lang_chain_tracer_v2.load_session("test-session-name")
|
||||
mock_requests_get.assert_called_with(
|
||||
"http://test-endpoint.com/sessions",
|
||||
headers={"Content-Type": "application/json", "x-api-key": "foo"},
|
||||
params={"tenant_id": "test-tenant-id", "name": "test-session-name"},
|
||||
)
|
||||
assert result == sample_tracer_session_v2
|
||||
|
||||
|
||||
def test_convert_run(
|
||||
lang_chain_tracer_v2: LangChainTracerV2,
|
||||
sample_tracer_session_v2: TracerSessionV2,
|
||||
sample_runs: Tuple[LLMRun, ChainRun, ToolRun],
|
||||
) -> None:
|
||||
llm_run, chain_run, tool_run = sample_runs
|
||||
lang_chain_tracer_v2.session = sample_tracer_session_v2
|
||||
converted_llm_run = lang_chain_tracer_v2._convert_run(llm_run)
|
||||
converted_chain_run = lang_chain_tracer_v2._convert_run(chain_run)
|
||||
converted_tool_run = lang_chain_tracer_v2._convert_run(tool_run)
|
||||
|
||||
assert isinstance(converted_llm_run, Run)
|
||||
assert isinstance(converted_chain_run, Run)
|
||||
assert isinstance(converted_tool_run, Run)
|
||||
|
||||
|
||||
@patch("langchain.callbacks.tracers.langchain.requests.post")
|
||||
def test_persist_run(
|
||||
mock_requests_post: Mock,
|
||||
lang_chain_tracer_v2: LangChainTracerV2,
|
||||
sample_tracer_session_v2: TracerSessionV2,
|
||||
sample_runs: Tuple[LLMRun, ChainRun, ToolRun],
|
||||
) -> None:
|
||||
mock_requests_post.return_value.raise_for_status.return_value = None
|
||||
lang_chain_tracer_v2.session = sample_tracer_session_v2
|
||||
llm_run, chain_run, tool_run = sample_runs
|
||||
lang_chain_tracer_v2._persist_run(llm_run)
|
||||
lang_chain_tracer_v2._persist_run(chain_run)
|
||||
lang_chain_tracer_v2._persist_run(tool_run)
|
||||
|
||||
assert mock_requests_post.call_count == 3
|
||||
|
||||
0
tests/unit_tests/tools/python/__init__.py
Normal file
0
tests/unit_tests/tools/python/__init__.py
Normal file
23
tests/unit_tests/tools/python/test_python.py
Normal file
23
tests/unit_tests/tools/python/test_python.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""Test Python REPL Tools."""
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.tools.python.tool import PythonAstREPLTool, PythonREPLTool
|
||||
|
||||
|
||||
def test_python_repl_tool_single_input() -> None:
|
||||
"""Test that the python REPL tool works with a single input."""
|
||||
tool = PythonREPLTool()
|
||||
assert tool.is_single_input
|
||||
assert int(tool.run("print(1 + 1)").strip()) == 2
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
|
||||
)
|
||||
def test_python_ast_repl_tool_single_input() -> None:
|
||||
"""Test that the python REPL tool works with a single input."""
|
||||
tool = PythonAstREPLTool()
|
||||
assert tool.is_single_input
|
||||
assert tool.run("1 + 1") == 2
|
||||
438
tests/unit_tests/tools/test_base.py
Normal file
438
tests/unit_tests/tools/test_base.py
Normal file
@@ -0,0 +1,438 @@
|
||||
"""Test the base tool implementation."""
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
from typing import Any, Optional, Type, Union
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.agents.tools import Tool, tool
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForToolRun,
|
||||
CallbackManagerForToolRun,
|
||||
)
|
||||
from langchain.tools.base import BaseTool, SchemaAnnotationError, StructuredTool
|
||||
|
||||
|
||||
def test_unnamed_decorator() -> None:
|
||||
"""Test functionality with unnamed decorator."""
|
||||
|
||||
@tool
|
||||
def search_api(query: str) -> str:
|
||||
"""Search the API for the query."""
|
||||
return "API result"
|
||||
|
||||
assert isinstance(search_api, BaseTool)
|
||||
assert search_api.name == "search_api"
|
||||
assert not search_api.return_direct
|
||||
assert search_api("test") == "API result"
|
||||
|
||||
|
||||
class _MockSchema(BaseModel):
|
||||
arg1: int
|
||||
arg2: bool
|
||||
arg3: Optional[dict] = None
|
||||
|
||||
|
||||
class _MockStructuredTool(BaseTool):
|
||||
name = "structured_api"
|
||||
args_schema: Type[BaseModel] = _MockSchema
|
||||
description = "A Structured Tool"
|
||||
|
||||
def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
|
||||
return f"{arg1} {arg2} {arg3}"
|
||||
|
||||
async def _arun(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def test_structured_args() -> None:
|
||||
"""Test functionality with structured arguments."""
|
||||
structured_api = _MockStructuredTool()
|
||||
assert isinstance(structured_api, BaseTool)
|
||||
assert structured_api.name == "structured_api"
|
||||
expected_result = "1 True {'foo': 'bar'}"
|
||||
args = {"arg1": 1, "arg2": True, "arg3": {"foo": "bar"}}
|
||||
assert structured_api.run(args) == expected_result
|
||||
|
||||
|
||||
def test_unannotated_base_tool_raises_error() -> None:
|
||||
"""Test that a BaseTool without type hints raises an exception.""" ""
|
||||
with pytest.raises(SchemaAnnotationError):
|
||||
|
||||
class _UnAnnotatedTool(BaseTool):
|
||||
name = "structured_api"
|
||||
# This would silently be ignored without the custom metaclass
|
||||
args_schema = _MockSchema
|
||||
description = "A Structured Tool"
|
||||
|
||||
def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
|
||||
return f"{arg1} {arg2} {arg3}"
|
||||
|
||||
async def _arun(
|
||||
self, arg1: int, arg2: bool, arg3: Optional[dict] = None
|
||||
) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def test_misannotated_base_tool_raises_error() -> None:
|
||||
"""Test that a BaseTool with the incorrrect typehint raises an exception.""" ""
|
||||
with pytest.raises(SchemaAnnotationError):
|
||||
|
||||
class _MisAnnotatedTool(BaseTool):
|
||||
name = "structured_api"
|
||||
# This would silently be ignored without the custom metaclass
|
||||
args_schema: BaseModel = _MockSchema # type: ignore
|
||||
description = "A Structured Tool"
|
||||
|
||||
def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
|
||||
return f"{arg1} {arg2} {arg3}"
|
||||
|
||||
async def _arun(
|
||||
self, arg1: int, arg2: bool, arg3: Optional[dict] = None
|
||||
) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def test_forward_ref_annotated_base_tool_accepted() -> None:
|
||||
"""Test that a using forward ref annotation syntax is accepted.""" ""
|
||||
|
||||
class _ForwardRefAnnotatedTool(BaseTool):
|
||||
name = "structured_api"
|
||||
args_schema: "Type[BaseModel]" = _MockSchema
|
||||
description = "A Structured Tool"
|
||||
|
||||
def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
|
||||
return f"{arg1} {arg2} {arg3}"
|
||||
|
||||
async def _arun(
|
||||
self, arg1: int, arg2: bool, arg3: Optional[dict] = None
|
||||
) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def test_subclass_annotated_base_tool_accepted() -> None:
|
||||
"""Test BaseTool child w/ custom schema isn't overwritten."""
|
||||
|
||||
class _ForwardRefAnnotatedTool(BaseTool):
|
||||
name = "structured_api"
|
||||
args_schema: Type[_MockSchema] = _MockSchema
|
||||
description = "A Structured Tool"
|
||||
|
||||
def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
|
||||
return f"{arg1} {arg2} {arg3}"
|
||||
|
||||
async def _arun(
|
||||
self, arg1: int, arg2: bool, arg3: Optional[dict] = None
|
||||
) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
assert issubclass(_ForwardRefAnnotatedTool, BaseTool)
|
||||
tool = _ForwardRefAnnotatedTool()
|
||||
assert tool.args_schema == _MockSchema
|
||||
|
||||
|
||||
def test_decorator_with_specified_schema() -> None:
|
||||
"""Test that manually specified schemata are passed through to the tool."""
|
||||
|
||||
@tool(args_schema=_MockSchema)
|
||||
def tool_func(arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
|
||||
"""Return the arguments directly."""
|
||||
return f"{arg1} {arg2} {arg3}"
|
||||
|
||||
assert isinstance(tool_func, BaseTool)
|
||||
assert tool_func.args_schema == _MockSchema
|
||||
|
||||
|
||||
def test_decorated_function_schema_equivalent() -> None:
|
||||
"""Test that a BaseTool without a schema meets expectations."""
|
||||
|
||||
@tool
|
||||
def structured_tool_input(
|
||||
arg1: int, arg2: bool, arg3: Optional[dict] = None
|
||||
) -> str:
|
||||
"""Return the arguments directly."""
|
||||
return f"{arg1} {arg2} {arg3}"
|
||||
|
||||
assert isinstance(structured_tool_input, BaseTool)
|
||||
assert structured_tool_input.args_schema is not None
|
||||
assert (
|
||||
structured_tool_input.args_schema.schema()["properties"]
|
||||
== _MockSchema.schema()["properties"]
|
||||
== structured_tool_input.args
|
||||
)
|
||||
|
||||
|
||||
def test_args_kwargs_filtered() -> None:
|
||||
class _SingleArgToolWithKwargs(BaseTool):
|
||||
name = "single_arg_tool"
|
||||
description = "A single arged tool with kwargs"
|
||||
|
||||
def _run(
|
||||
self,
|
||||
some_arg: str,
|
||||
run_manager: Optional[CallbackManagerForToolRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
return "foo"
|
||||
|
||||
async def _arun(
|
||||
self,
|
||||
some_arg: str,
|
||||
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
tool = _SingleArgToolWithKwargs()
|
||||
assert tool.is_single_input
|
||||
|
||||
class _VarArgToolWithKwargs(BaseTool):
|
||||
name = "single_arg_tool"
|
||||
description = "A single arged tool with kwargs"
|
||||
|
||||
def _run(
|
||||
self,
|
||||
*args: Any,
|
||||
run_manager: Optional[CallbackManagerForToolRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
return "foo"
|
||||
|
||||
async def _arun(
|
||||
self,
|
||||
*args: Any,
|
||||
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
tool2 = _VarArgToolWithKwargs()
|
||||
assert tool2.is_single_input
|
||||
|
||||
|
||||
def test_structured_args_decorator_no_infer_schema() -> None:
|
||||
"""Test functionality with structured arguments parsed as a decorator."""
|
||||
|
||||
@tool(infer_schema=False)
|
||||
def structured_tool_input(
|
||||
arg1: int, arg2: Union[float, datetime], opt_arg: Optional[dict] = None
|
||||
) -> str:
|
||||
"""Return the arguments directly."""
|
||||
return f"{arg1}, {arg2}, {opt_arg}"
|
||||
|
||||
assert isinstance(structured_tool_input, BaseTool)
|
||||
assert structured_tool_input.name == "structured_tool_input"
|
||||
args = {"arg1": 1, "arg2": 0.001, "opt_arg": {"foo": "bar"}}
|
||||
expected_result = "1, 0.001, {'foo': 'bar'}"
|
||||
with pytest.raises(ValueError):
|
||||
assert structured_tool_input.run(args) == expected_result
|
||||
|
||||
|
||||
def test_structured_single_str_decorator_no_infer_schema() -> None:
|
||||
"""Test functionality with structured arguments parsed as a decorator."""
|
||||
|
||||
@tool(infer_schema=False)
|
||||
def unstructured_tool_input(tool_input: str) -> str:
|
||||
"""Return the arguments directly."""
|
||||
return f"{tool_input}"
|
||||
|
||||
assert isinstance(unstructured_tool_input, BaseTool)
|
||||
assert unstructured_tool_input.args_schema is None
|
||||
assert unstructured_tool_input.run("foo") == "foo"
|
||||
|
||||
|
||||
def test_base_tool_inheritance_base_schema() -> None:
|
||||
"""Test schema is correctly inferred when inheriting from BaseTool."""
|
||||
|
||||
class _MockSimpleTool(BaseTool):
|
||||
name = "simple_tool"
|
||||
description = "A Simple Tool"
|
||||
|
||||
def _run(self, tool_input: str) -> str:
|
||||
return f"{tool_input}"
|
||||
|
||||
async def _arun(self, tool_input: str) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
simple_tool = _MockSimpleTool()
|
||||
assert simple_tool.args_schema is None
|
||||
expected_args = {"tool_input": {"title": "Tool Input", "type": "string"}}
|
||||
assert simple_tool.args == expected_args
|
||||
|
||||
|
||||
def test_tool_lambda_args_schema() -> None:
|
||||
"""Test args schema inference when the tool argument is a lambda function."""
|
||||
|
||||
tool = Tool(
|
||||
name="tool",
|
||||
description="A tool",
|
||||
func=lambda tool_input: tool_input,
|
||||
)
|
||||
assert tool.args_schema is None
|
||||
expected_args = {"tool_input": {"type": "string"}}
|
||||
assert tool.args == expected_args
|
||||
|
||||
|
||||
def test_structured_tool_lambda_multi_args_schema() -> None:
|
||||
"""Test args schema inference when the tool argument is a lambda function."""
|
||||
tool = StructuredTool.from_function(
|
||||
name="tool",
|
||||
description="A tool",
|
||||
func=lambda tool_input, other_arg: f"{tool_input}{other_arg}", # type: ignore
|
||||
)
|
||||
assert tool.args_schema is not None
|
||||
expected_args = {
|
||||
"tool_input": {"title": "Tool Input"},
|
||||
"other_arg": {"title": "Other Arg"},
|
||||
}
|
||||
assert tool.args == expected_args
|
||||
|
||||
|
||||
def test_tool_partial_function_args_schema() -> None:
|
||||
"""Test args schema inference when the tool argument is a partial function."""
|
||||
|
||||
def func(tool_input: str, other_arg: str) -> str:
|
||||
return tool_input + other_arg
|
||||
|
||||
tool = Tool(
|
||||
name="tool",
|
||||
description="A tool",
|
||||
func=partial(func, other_arg="foo"),
|
||||
)
|
||||
assert tool.run("bar") == "barfoo"
|
||||
|
||||
|
||||
def test_empty_args_decorator() -> None:
|
||||
"""Test inferred schema of decorated fn with no args."""
|
||||
|
||||
@tool
|
||||
def empty_tool_input() -> str:
|
||||
"""Return a constant."""
|
||||
return "the empty result"
|
||||
|
||||
assert isinstance(empty_tool_input, BaseTool)
|
||||
assert empty_tool_input.name == "empty_tool_input"
|
||||
assert empty_tool_input.args == {}
|
||||
assert empty_tool_input.run({}) == "the empty result"
|
||||
|
||||
|
||||
def test_named_tool_decorator() -> None:
|
||||
"""Test functionality when arguments are provided as input to decorator."""
|
||||
|
||||
@tool("search")
|
||||
def search_api(query: str) -> str:
|
||||
"""Search the API for the query."""
|
||||
return "API result"
|
||||
|
||||
assert isinstance(search_api, BaseTool)
|
||||
assert search_api.name == "search"
|
||||
assert not search_api.return_direct
|
||||
|
||||
|
||||
def test_named_tool_decorator_return_direct() -> None:
|
||||
"""Test functionality when arguments and return direct are provided as input."""
|
||||
|
||||
@tool("search", return_direct=True)
|
||||
def search_api(query: str) -> str:
|
||||
"""Search the API for the query."""
|
||||
return "API result"
|
||||
|
||||
assert isinstance(search_api, BaseTool)
|
||||
assert search_api.name == "search"
|
||||
assert search_api.return_direct
|
||||
|
||||
|
||||
def test_unnamed_tool_decorator_return_direct() -> None:
|
||||
"""Test functionality when only return direct is provided."""
|
||||
|
||||
@tool(return_direct=True)
|
||||
def search_api(query: str) -> str:
|
||||
"""Search the API for the query."""
|
||||
return "API result"
|
||||
|
||||
assert isinstance(search_api, BaseTool)
|
||||
assert search_api.name == "search_api"
|
||||
assert search_api.return_direct
|
||||
|
||||
|
||||
def test_tool_with_kwargs() -> None:
    """Test invoking a multi-argument tool with keyword args and defaults.

    Covers three call shapes: all kwargs supplied, only the required kwarg
    supplied (defaults fill the rest), and the legacy single-string input.
    """
    # NOTE: docstring above fixes a copy-paste error — the original said
    # "Test functionality when only return direct is provided", which
    # described a different test.

    @tool(return_direct=True)
    def search_api(
        arg_0: str,
        arg_1: float = 4.3,
        ping: str = "hi",
    ) -> str:
        """Search the API for the query."""
        return f"arg_0={arg_0}, arg_1={arg_1}, ping={ping}"

    assert isinstance(search_api, BaseTool)

    # All arguments supplied explicitly.
    result = search_api.run(
        tool_input={
            "arg_0": "foo",
            "arg_1": 3.2,
            "ping": "pong",
        }
    )
    assert result == "arg_0=foo, arg_1=3.2, ping=pong"

    # Omitted arguments fall back to their declared defaults.
    result = search_api.run(
        tool_input={
            "arg_0": "foo",
        }
    )
    assert result == "arg_0=foo, arg_1=4.3, ping=hi"

    # For backwards compatibility, we still accept a single str arg,
    # which is bound to the first parameter.
    result = search_api.run("foobar")
    assert result == "arg_0=foobar, arg_1=4.3, ping=hi"
|
||||
|
||||
|
||||
def test_missing_docstring() -> None:
    """Test error is raised when docstring is missing."""
    # Decorating a function that has no docstring should raise immediately,
    # since the docstring becomes the tool's description.
    with pytest.raises(AssertionError, match="Function must have a docstring"):

        @tool
        def search_api(query: str) -> str:
            return "API result"
|
||||
|
||||
|
||||
def test_create_tool_positional_args() -> None:
    """Tool can be constructed with positional name/func/description."""
    identity_tool = Tool("test_name", lambda x: x, "test_description")
    # Positional arguments map to name, func, description respectively.
    assert identity_tool.name == "test_name"
    assert identity_tool.description == "test_description"
    assert identity_tool.is_single_input
    # Calling the tool delegates to the wrapped function.
    assert identity_tool("foo") == "foo"
|
||||
|
||||
|
||||
def test_create_tool_keyword_args() -> None:
    """Tool can be constructed entirely from keyword arguments."""
    identity_tool = Tool(
        name="test_name",
        func=lambda x: x,
        description="test_description",
    )
    assert identity_tool.is_single_input
    # Invocation and metadata behave the same as with positional construction.
    assert identity_tool("foo") == "foo"
    assert identity_tool.name == "test_name"
    assert identity_tool.description == "test_description"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_create_async_tool() -> None:
    """A Tool built with a coroutine should support both run and arun."""

    async def _echo(x: str) -> str:
        return x

    async_tool = Tool(
        name="test_name",
        func=lambda x: x,
        description="test_description",
        coroutine=_echo,
    )
    assert async_tool.is_single_input
    # Synchronous invocation still uses the plain func.
    assert async_tool("foo") == "foo"
    assert async_tool.name == "test_name"
    assert async_tool.description == "test_description"
    # The coroutine is stored and used by the async entry point.
    assert async_tool.coroutine is not None
    assert await async_tool.arun("foo") == "foo"
|
||||
@@ -70,7 +70,7 @@ def test_success(mocked_responses: responses.RequestsMock, ref: str) -> None:
|
||||
assert file_contents is None
|
||||
file_contents = Path(file_path).read_text()
|
||||
|
||||
mocked_responses.get(
|
||||
mocked_responses.get( # type: ignore
|
||||
urljoin(URL_BASE.format(ref=ref), path),
|
||||
body=body,
|
||||
status=200,
|
||||
@@ -86,7 +86,9 @@ def test_failed_request(mocked_responses: responses.RequestsMock) -> None:
|
||||
path = "chains/path/chain.json"
|
||||
loader = Mock()
|
||||
|
||||
mocked_responses.get(urljoin(URL_BASE.format(ref=DEFAULT_REF), path), status=500)
|
||||
mocked_responses.get( # type: ignore
|
||||
urljoin(URL_BASE.format(ref=DEFAULT_REF), path), status=500
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match=re.compile("Could not find file at .*")):
|
||||
try_load_from_hub(f"lc://{path}", loader, "chains", {"json"})
|
||||
|
||||
22
tests/unit_tests/vectorstores/test_pgvector.py
Normal file
22
tests/unit_tests/vectorstores/test_pgvector.py
Normal file
@@ -0,0 +1,22 @@
|
||||
import os
|
||||
|
||||
from langchain.vectorstores.pgvector import PGVECTOR_VECTOR_SIZE, EmbeddingStore
|
||||
|
||||
|
||||
def test_embedding_store_init_defaults() -> None:
    """With no arguments, the embedding column uses the module default size."""
    store = EmbeddingStore()
    assert store.embedding.type.dim == PGVECTOR_VECTOR_SIZE
|
||||
|
||||
|
||||
def test_embedding_store_init_vector_size() -> None:
    """An explicit vector_size should set the embedding column dimension."""
    store = EmbeddingStore(vector_size=2)
    assert store.embedding.type.dim == 2
|
||||
|
||||
|
||||
def test_embedding_store_init_env_vector_size() -> None:
    """The PGVECTOR_VECTOR_SIZE env var should drive the embedding dimension.

    The original test set the environment variable without restoring it,
    leaking state into every subsequently-run test (including the default-size
    test above if it runs later in the session). Restore it in a finally block.
    """
    previous = os.environ.get("PGVECTOR_VECTOR_SIZE")
    os.environ["PGVECTOR_VECTOR_SIZE"] = "3"
    try:
        # NOTE(review): assumes EmbeddingStore reads the env var at
        # construction time rather than at module import — confirm in
        # langchain.vectorstores.pgvector.
        assert EmbeddingStore().embedding.type.dim == 3
    finally:
        # Restore prior state so the env var does not leak into other tests.
        if previous is None:
            del os.environ["PGVECTOR_VECTOR_SIZE"]
        else:
            os.environ["PGVECTOR_VECTOR_SIZE"] = previous
|
||||
Reference in New Issue
Block a user