Compare commits

...

7 Commits

Author SHA1 Message Date
Dev 2049
64d1e2042e cr 2023-05-05 16:33:32 -07:00
Martin Holzhauer
0a7245f721 env variable based fix for non chatgpt embeddings (#3964)
this is a simple fix for #2219 

i also added some documentation for this environment variable
2023-05-05 16:31:38 -07:00
Zander Chase
6032a051e9 Add Tenant ID to V2 Tracer (#4135)
Update the V2 tracer to
- use UUIDs instead of int's
- load a tenant ID and use that when saving sessions
2023-05-04 21:35:20 -07:00
Zander Chase
fea639c1fc Vwp/sqlalchemy (#4145)
Bump threshold to 1.4 from 1.3. Change import to be compatible

Resolves #4142 and #4129

---------

Co-authored-by: ndaugreal <ndaugreal@gmail.com>
Co-authored-by: Jeremy Lopez <lopez86@users.noreply.github.com>
2023-05-04 20:46:38 -07:00
Zander Chase
2f087d63af Fix Python RePL Tool (#4137)
Filter out kwargs from inferred schema when determining if a tool is
single input.

Add a couple unit tests.

Move tool unit tests to the tools dir
2023-05-04 20:31:16 -07:00
Zander Chase
cc068f1b77 Add Issue Templates (#4021)
Add issue templates for
- bug reports
- feature suggestions
- documentation
and a link to the discord for general discussion.

Open to other suggestions here. Could also add another "Other" template
with just a raw text box if we think this is too restrictive


<img width="1464" alt="image"
src="https://user-images.githubusercontent.com/130414180/236115358-e603bcbe-282c-40c7-82eb-905eb93ccec0.png">
2023-05-04 16:33:52 -07:00
Zander Chase
ac0a9d02bd Visual Studio Code/Github Codespaces Dev Containers (#4035) (#4122)
Having dev containers makes its easier, faster and secure to setup the
dev environment for the repository.

The pull request consists of:

- .devcontainer folder with:
- **devcontainer.json :** (minimal necessary vscode extensions and
settings)
- **docker-compose.yaml :** (could be modified to run necessary services
as per need. Ex vectordbs, databases)
    - **Dockerfile:**(non root with dev tools)
- Changes to README - added the Open in Github Codespaces Badge - added
the Open in dev container Badge

Co-authored-by: Jinto Jose <129657162+jj701@users.noreply.github.com>
2023-05-04 11:37:00 -07:00
26 changed files with 1072 additions and 429 deletions

42
.devcontainer/Dockerfile Normal file
View File

@@ -0,0 +1,42 @@
# This is a Dockerfile for Developer Container
# Use the Python base image
ARG VARIANT="3.11-bullseye"
FROM mcr.microsoft.com/vscode/devcontainers/python:0-${VARIANT} AS langchain-dev-base
USER vscode
# Define the version of Poetry to install (default is 1.4.2)
# Define the directory of python virtual environment
ARG PYTHON_VIRTUALENV_HOME=/home/vscode/langchain-py-env \
POETRY_VERSION=1.4.2
ENV POETRY_VIRTUALENVS_IN_PROJECT=false \
POETRY_NO_INTERACTION=true
# Create a Python virtual environment for Poetry and install it
RUN python3 -m venv ${PYTHON_VIRTUALENV_HOME} && \
$PYTHON_VIRTUALENV_HOME/bin/pip install --upgrade pip && \
$PYTHON_VIRTUALENV_HOME/bin/pip install poetry==${POETRY_VERSION}
ENV PATH="$PYTHON_VIRTUALENV_HOME/bin:$PATH" \
VIRTUAL_ENV=$PYTHON_VIRTUALENV_HOME
# Setup for bash
RUN poetry completions bash >> /home/vscode/.bash_completion && \
echo "export PATH=$PYTHON_VIRTUALENV_HOME/bin:$PATH" >> ~/.bashrc
# Set the working directory for the app
WORKDIR /workspaces/langchain
# Use a multi-stage build to install dependencies
FROM langchain-dev-base AS langchain-dev-dependencies
ARG PYTHON_VIRTUALENV_HOME
# Copy only the dependency files for installation
COPY pyproject.toml poetry.lock poetry.toml ./
# Install the Poetry dependencies (this layer will be cached as long as the dependencies don't change)
RUN poetry install --no-interaction --no-ansi --with dev,test,docs

View File

@@ -0,0 +1,33 @@
// For format details, see https://aka.ms/devcontainer.json. For config options, see the
// README at: https://github.com/devcontainers/templates/tree/main/src/docker-existing-dockerfile
{
"dockerComposeFile": "./docker-compose.yaml",
"service": "langchain",
"workspaceFolder": "/workspaces/langchain",
"name": "langchain",
"customizations": {
"vscode": {
"extensions": [
"ms-python.python"
],
"settings": {
"python.defaultInterpreterPath": "/home/vscode/langchain-py-env/bin/python3.11"
}
}
},
// Features to add to the dev container. More info: https://containers.dev/features.
"features": {},
// Use 'forwardPorts' to make a list of ports inside the container available locally.
// "forwardPorts": [],
// Uncomment the next line to run commands after the container is created.
// "postCreateCommand": "cat /etc/os-release",
// Uncomment to connect as an existing user other than the container default. More info: https://aka.ms/dev-containers-non-root.
// "remoteUser": "devcontainer"
"remoteUser": "vscode",
"overrideCommand": true
}

View File

@@ -0,0 +1,31 @@
version: '3'
services:
langchain:
build:
dockerfile: .devcontainer/Dockerfile
context: ../
volumes:
- ../:/workspaces/langchain
networks:
- langchain-network
# environment:
# MONGO_ROOT_USERNAME: root
# MONGO_ROOT_PASSWORD: example123
# depends_on:
# - mongo
# mongo:
# image: mongo
# restart: unless-stopped
# environment:
# MONGO_INITDB_ROOT_USERNAME: root
# MONGO_INITDB_ROOT_PASSWORD: example123
# ports:
# - "27017:27017"
# networks:
# - langchain-network
networks:
langchain-network:
driver: bridge

106
.github/ISSUE_TEMPLATE/bug-report.yml vendored Normal file
View File

@@ -0,0 +1,106 @@
name: "\U0001F41B Bug Report"
description: Submit a bug report to help us improve LangChain
labels: ["02 Bug Report"]
body:
- type: markdown
attributes:
value: >
Thank you for taking the time to file a bug report. Before creating a new
issue, please make sure to take a few moments to check the issue tracker
for existing issues about the bug.
- type: textarea
id: system-info
attributes:
label: System Info
description: Please share your system info with us.
placeholder: LangChain version, platform, python version, ...
validations:
required: true
- type: textarea
id: who-can-help
attributes:
label: Who can help?
description: |
Your issue will be replied to more quickly if you can figure out the right person to tag with @
If you know how to use git blame, that is the easiest way, otherwise, here is a rough guide of **who to tag**.
The core maintainers strive to read all issues, but tagging them will help them prioritize.
Please tag fewer than 3 people.
@hwchase17 - project lead
Tracing / Callbacks
- @agola11
Async
- @agola11
DataLoader Abstractions
- @eyurtsev
LLM/Chat Wrappers
- @hwchase17
- @agola11
Tools / Toolkits
- @vowelparrot
placeholder: "@Username ..."
- type: checkboxes
id: information-scripts-examples
attributes:
label: Information
description: "The problem arises when using:"
options:
- label: "The official example notebooks/scripts"
- label: "My own modified scripts"
- type: checkboxes
id: related-components
attributes:
label: Related Components
description: "Select the components related to the issue (if applicable):"
options:
- label: "LLMs/Chat Models"
- label: "Embedding Models"
- label: "Prompts / Prompt Templates / Prompt Selectors"
- label: "Output Parsers"
- label: "Document Loaders"
- label: "Vector Stores / Retrievers"
- label: "Memory"
- label: "Agents / Agent Executors"
- label: "Tools / Toolkits"
- label: "Chains"
- label: "Callbacks/Tracing"
- label: "Async"
- type: textarea
id: reproduction
validations:
required: true
attributes:
label: Reproduction
description: |
Please provide a [code sample](https://stackoverflow.com/help/minimal-reproducible-example) that reproduces the problem you ran into. It can be a Colab link or just a code snippet.
If you have code snippets, error messages, stack traces please provide them here as well.
Important! Use code tags to correctly format your code. See https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting
Avoid screenshots when possible, as they are hard to read and (more importantly) don't allow others to copy-and-paste your code.
placeholder: |
Steps to reproduce the behavior:
1.
2.
3.
- type: textarea
id: expected-behavior
validations:
required: true
attributes:
label: Expected behavior
description: "A clear and concise description of what you would expect to happen."

6
.github/ISSUE_TEMPLATE/config.yml vendored Normal file
View File

@@ -0,0 +1,6 @@
blank_issues_enabled: true
version: 2.1
contact_links:
- name: Discord
url: https://discord.gg/6adMQxSpJS
about: General community discussions

View File

@@ -0,0 +1,19 @@
name: Documentation
description: Report an issue related to the LangChain documentation.
title: "DOC: <Please write a comprehensive title after the 'DOC: ' prefix>"
labels: [03 - Documentation]
body:
- type: textarea
attributes:
label: "Issue with current documentation:"
description: >
Please make sure to leave a reference to the document/code you're
referring to.
- type: textarea
attributes:
label: "Idea or request for content:"
description: >
Please describe as clearly as possible what topics you think are missing
from the current documentation.

View File

@@ -0,0 +1,30 @@
name: "\U0001F680 Feature request"
description: Submit a proposal/request for a new LangChain feature
labels: ["02 Feature Request"]
body:
- type: textarea
id: feature-request
validations:
required: true
attributes:
label: Feature request
description: |
A clear and concise description of the feature proposal. Please provide links to any relevant GitHub repos, papers, or other resources if relevant.
- type: textarea
id: motivation
validations:
required: true
attributes:
label: Motivation
description: |
Please outline the motivation for the proposal. Is your feature request related to a problem? e.g., I'm always frustrated when [...]. If this is related to another GitHub issue, please link here too.
- type: textarea
id: contribution
validations:
required: true
attributes:
label: Your contribution
description: |
Is there any way that you could help, e.g. by submitting a PR? Make sure to read the CONTRIBUTING.MD [readme](https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md)

18
.github/ISSUE_TEMPLATE/other.yml vendored Normal file
View File

@@ -0,0 +1,18 @@
name: Other Issue
description: Raise an issue that wouldn't be covered by the other templates.
title: "Issue: <Please write a comprehensive title after the 'Issue: ' prefix>"
labels: [04 - Other]
body:
- type: textarea
attributes:
label: "Issue you'd like to raise."
description: >
Please describe the issue you'd like to raise as clearly as possible.
Make sure to include any relevant links or references.
- type: textarea
attributes:
label: "Suggestion:"
description: >
Please outline a suggestion to improve the issue here.

View File

@@ -2,7 +2,8 @@
⚡ Building applications with LLMs through composability ⚡
[![lint](https://github.com/hwchase17/langchain/actions/workflows/lint.yml/badge.svg)](https://github.com/hwchase17/langchain/actions/workflows/lint.yml) [![test](https://github.com/hwchase17/langchain/actions/workflows/test.yml/badge.svg)](https://github.com/hwchase17/langchain/actions/workflows/test.yml) [![linkcheck](https://github.com/hwchase17/langchain/actions/workflows/linkcheck.yml/badge.svg)](https://github.com/hwchase17/langchain/actions/workflows/linkcheck.yml) [![Downloads](https://static.pepy.tech/badge/langchain/month)](https://pepy.tech/project/langchain) [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) [![Twitter](https://img.shields.io/twitter/url/https/twitter.com/langchainai.svg?style=social&label=Follow%20%40LangChainAI)](https://twitter.com/langchainai) [![](https://dcbadge.vercel.app/api/server/6adMQxSpJS?compact=true&style=flat)](https://discord.gg/6adMQxSpJS)
[![lint](https://github.com/hwchase17/langchain/actions/workflows/lint.yml/badge.svg)](https://github.com/hwchase17/langchain/actions/workflows/lint.yml) [![test](https://github.com/hwchase17/langchain/actions/workflows/test.yml/badge.svg)](https://github.com/hwchase17/langchain/actions/workflows/test.yml) [![linkcheck](https://github.com/hwchase17/langchain/actions/workflows/linkcheck.yml/badge.svg)](https://github.com/hwchase17/langchain/actions/workflows/linkcheck.yml) [![Downloads](https://static.pepy.tech/badge/langchain/month)](https://pepy.tech/project/langchain) [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) [![Twitter](https://img.shields.io/twitter/url/https/twitter.com/langchainai.svg?style=social&label=Follow%20%40LangChainAI)](https://twitter.com/langchainai) [![](https://dcbadge.vercel.app/api/server/6adMQxSpJS?compact=true&style=flat)](https://discord.gg/6adMQxSpJS) [![Open in Dev Containers](https://img.shields.io/static/v1?label=Dev%20Containers&message=Open&color=blue&logo=visualstudiocode)](https://vscode.dev/redirect?url=vscode://ms-vscode-remote.remote-containers/cloneInVolume?url=https://github.com/hwchase17/langchain) [![Open in GitHub Codespaces](https://github.com/codespaces/badge.svg)](https://codespaces.new/hwchase17/langchain)
Looking for the JS/TS version? Check out [LangChain.js](https://github.com/hwchase17/langchainjs).

View File

@@ -24,6 +24,10 @@ To import this vectorstore:
from langchain.vectorstores.pgvector import PGVector
```
PGVector embedding size is not autodetected. If you are using ChatGPT or any other embedding with 1536 dimensions
default is fine. If you are going to use for example HuggingFaceEmbeddings you need to set the environment variable named `PGVECTOR_VECTOR_SIZE`
to the needed value, In case of HuggingFaceEmbeddings is would be: `PGVECTOR_VECTOR_SIZE=768`
### Usage
For a more detailed walkthrough of the PGVector Wrapper, see [this notebook](../modules/indexes/vectorstores/examples/pgvector.ipynb)

View File

@@ -6,7 +6,7 @@ import os
import warnings
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Any, Dict, Generator, List, Optional, Type, TypeVar, Union
from typing import Any, Dict, Generator, List, Optional, Type, TypeVar, Union, cast
from uuid import UUID, uuid4
from langchain.callbacks.base import (
@@ -21,6 +21,7 @@ from langchain.callbacks.openai_info import OpenAICallbackHandler
from langchain.callbacks.stdout import StdOutCallbackHandler
from langchain.callbacks.tracers.base import TracerSession
from langchain.callbacks.tracers.langchain import LangChainTracer, LangChainTracerV2
from langchain.callbacks.tracers.schemas import TracerSessionV2
from langchain.schema import AgentAction, AgentFinish, LLMResult
Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]]
@@ -28,7 +29,7 @@ Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]]
openai_callback_var: ContextVar[Optional[OpenAICallbackHandler]] = ContextVar(
"openai_callback", default=None
)
tracing_callback_var: ContextVar[Optional[LangChainTracer]] = ContextVar(
tracing_callback_var: ContextVar[Optional[LangChainTracer]] = ContextVar( # noqa: E501
"tracing_callback", default=None
)
@@ -48,7 +49,7 @@ def tracing_enabled(
) -> Generator[TracerSession, None, None]:
"""Get Tracer in a context manager."""
cb = LangChainTracer()
session = cb.load_session(session_name)
session = cast(TracerSession, cb.load_session(session_name))
tracing_callback_var.set(cb)
yield session
tracing_callback_var.set(None)
@@ -57,7 +58,7 @@ def tracing_enabled(
@contextmanager
def tracing_v2_enabled(
session_name: str = "default",
) -> Generator[TracerSession, None, None]:
) -> Generator[TracerSessionV2, None, None]:
"""Get the experimental tracer handler in a context manager."""
# Issue a warning that this is experimental
warnings.warn(

View File

@@ -12,7 +12,9 @@ from langchain.callbacks.tracers.schemas import (
LLMRun,
ToolRun,
TracerSession,
TracerSessionBase,
TracerSessionCreate,
TracerSessionV2,
)
from langchain.schema import LLMResult
@@ -27,7 +29,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.run_map: Dict[str, Union[LLMRun, ChainRun, ToolRun]] = {}
self.session: Optional[TracerSession] = None
self.session: Optional[Union[TracerSessionV2, TracerSession]] = None
@staticmethod
def _add_child_run(
@@ -49,22 +51,31 @@ class BaseTracer(BaseCallbackHandler, ABC):
"""Persist a run."""
@abstractmethod
def _persist_session(self, session: TracerSessionCreate) -> TracerSession:
def _persist_session(
self, session: TracerSessionBase
) -> Union[TracerSession, TracerSessionV2]:
"""Persist a tracing session."""
def new_session(self, name: Optional[str] = None, **kwargs: Any) -> TracerSession:
def _get_session_create(
self, name: Optional[str] = None, **kwargs: Any
) -> TracerSessionBase:
return TracerSessionCreate(name=name, extra=kwargs)
def new_session(
self, name: Optional[str] = None, **kwargs: Any
) -> Union[TracerSession, TracerSessionV2]:
"""NOT thread safe, do not call this method from multiple threads."""
session_create = TracerSessionCreate(name=name, extra=kwargs)
session_create = self._get_session_create(name=name, **kwargs)
session = self._persist_session(session_create)
self.session = session
return session
@abstractmethod
def load_session(self, session_name: str) -> TracerSession:
def load_session(self, session_name: str) -> Union[TracerSession, TracerSessionV2]:
"""Load a tracing session and set it as the Tracer's session."""
@abstractmethod
def load_default_session(self) -> TracerSession:
def load_default_session(self) -> Union[TracerSession, TracerSessionV2]:
"""Load the default tracing session and set it as the Tracer's session."""
def _start_trace(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:

View File

@@ -14,21 +14,32 @@ from langchain.callbacks.tracers.schemas import (
Run,
ToolRun,
TracerSession,
TracerSessionCreate,
TracerSessionBase,
TracerSessionV2,
TracerSessionV2Create,
)
def _get_headers() -> Dict[str, Any]:
"""Get the headers for the LangChain API."""
headers: Dict[str, Any] = {"Content-Type": "application/json"}
if os.getenv("LANGCHAIN_API_KEY"):
headers["x-api-key"] = os.getenv("LANGCHAIN_API_KEY")
return headers
def _get_endpoint() -> str:
return os.getenv("LANGCHAIN_ENDPOINT", "http://localhost:8000")
class LangChainTracer(BaseTracer):
"""An implementation of the SharedTracer that POSTS to the langchain endpoint."""
def __init__(self, session_name: str = "default", **kwargs: Any) -> None:
def __init__(self, **kwargs: Any) -> None:
"""Initialize the LangChain tracer."""
super().__init__(**kwargs)
self._endpoint: str = os.getenv("LANGCHAIN_ENDPOINT", "http://localhost:8000")
self._headers: Dict[str, Any] = {"Content-Type": "application/json"}
if os.getenv("LANGCHAIN_API_KEY"):
self._headers["x-api-key"] = os.getenv("LANGCHAIN_API_KEY")
self.session = self.load_session(session_name)
self._endpoint = _get_endpoint()
self._headers = _get_headers()
def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
"""Persist a run."""
@@ -48,7 +59,9 @@ class LangChainTracer(BaseTracer):
except Exception as e:
logging.warning(f"Failed to persist run: {e}")
def _persist_session(self, session_create: TracerSessionCreate) -> TracerSession:
def _persist_session(
self, session_create: TracerSessionBase
) -> Union[TracerSession, TracerSessionV2]:
"""Persist a session."""
try:
r = requests.post(
@@ -81,22 +94,89 @@ class LangChainTracer(BaseTracer):
self.session = tracer_session
return tracer_session
def load_session(self, session_name: str) -> TracerSession:
def load_session(self, session_name: str) -> Union[TracerSession, TracerSessionV2]:
"""Load a session with the given name from the tracer."""
return self._load_session(session_name)
def load_default_session(self) -> TracerSession:
def load_default_session(self) -> Union[TracerSession, TracerSessionV2]:
"""Load the default tracing session and set it as the Tracer's session."""
return self._load_session("default")
def _get_tenant_id() -> Optional[str]:
"""Get the tenant ID for the LangChain API."""
tenant_id: Optional[str] = os.getenv("LANGCHAIN_TENANT_ID")
if tenant_id:
return tenant_id
endpoint = _get_endpoint()
headers = _get_headers()
response = requests.get(endpoint + "/tenants", headers=headers)
response.raise_for_status()
tenants: List[Dict[str, Any]] = response.json()
if not tenants:
raise ValueError(f"No tenants found for URL {endpoint}")
return tenants[0]["id"]
class LangChainTracerV2(LangChainTracer):
"""An implementation of the SharedTracer that POSTS to the langchain endpoint."""
@staticmethod
def _convert_run(run: Union[LLMRun, ChainRun, ToolRun]) -> Run:
"""Convert a run to a Run."""
def __init__(self, **kwargs: Any) -> None:
"""Initialize the LangChain tracer."""
super().__init__(**kwargs)
self._endpoint = _get_endpoint()
self._headers = _get_headers()
self.tenant_id = _get_tenant_id()
def _get_session_create(
self, name: Optional[str] = None, **kwargs: Any
) -> TracerSessionBase:
return TracerSessionV2Create(name=name, extra=kwargs, tenant_id=self.tenant_id)
def _persist_session(self, session_create: TracerSessionBase) -> TracerSessionV2:
"""Persist a session."""
try:
r = requests.post(
f"{self._endpoint}/sessions",
data=session_create.json(),
headers=self._headers,
)
session = TracerSessionV2(id=r.json()["id"], **session_create.dict())
except Exception as e:
logging.warning(f"Failed to create session, using default session: {e}")
session = self.load_session("default")
return session
def _get_default_query_params(self) -> Dict[str, Any]:
"""Get the query params for the LangChain API."""
return {"tenant_id": self.tenant_id}
def load_session(self, session_name: str) -> TracerSessionV2:
"""Load a session with the given name from the tracer."""
try:
url = f"{self._endpoint}/sessions"
params = {"tenant_id": self.tenant_id}
if session_name:
params["name"] = session_name
r = requests.get(url, headers=self._headers, params=params)
tracer_session = TracerSessionV2(**r.json()[0])
except Exception as e:
session_type = "default" if not session_name else session_name
logging.warning(
f"Failed to load {session_type} session, using empty session: {e}"
)
tracer_session = TracerSessionV2(id=1, tenant_id=self.tenant_id)
self.session = tracer_session
return tracer_session
def load_default_session(self) -> TracerSessionV2:
"""Load the default tracing session and set it as the Tracer's session."""
return self.load_session("default")
def _convert_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> Run:
"""Convert a run to a Run."""
session = self.session or self.load_default_session()
inputs: Dict[str, Any] = {}
outputs: Optional[Dict[str, Any]] = None
child_runs: List[Union[LLMRun, ChainRun, ToolRun]] = []
@@ -126,30 +206,30 @@ class LangChainTracerV2(LangChainTracer):
return Run(
id=run.uuid,
name=run.serialized.get("name"),
name=run.serialized.get("name", f"{run_type}-{run.uuid}"),
start_time=run.start_time,
end_time=run.end_time,
extra=run.extra,
extra=run.extra or {},
error=run.error,
execution_order=run.execution_order,
serialized=run.serialized,
inputs=inputs,
outputs=outputs,
session_id=run.session_id,
session_id=session.id,
run_type=run_type,
parent_run_id=run.parent_uuid,
child_runs=[LangChainTracerV2._convert_run(child) for child in child_runs],
child_runs=[self._convert_run(child) for child in child_runs],
)
def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
"""Persist a run."""
run_create = self._convert_run(run)
try:
requests.post(
result = requests.post(
f"{self._endpoint}/runs",
data=run_create.json(),
headers=self._headers,
)
result.raise_for_status()
except Exception as e:
logging.warning(f"Failed to persist run: {e}")

View File

@@ -31,6 +31,24 @@ class TracerSession(TracerSessionBase):
id: int
class TracerSessionV2Base(TracerSessionBase):
"""A creation class for TracerSessionV2."""
tenant_id: UUID
class TracerSessionV2Create(TracerSessionBase):
"""A creation class for TracerSessionV2."""
pass
class TracerSessionV2(TracerSessionV2Base):
"""TracerSession schema for the V2 API."""
id: UUID
class BaseRun(BaseModel):
"""Base class for Run."""
@@ -93,9 +111,9 @@ class Run(BaseModel):
serialized: dict
inputs: dict
outputs: Optional[dict]
session_id: int
session_id: UUID
parent_run_id: Optional[UUID]
example_id: Optional[UUID]
reference_example_id: Optional[UUID]
run_type: RunTypeEnum
child_runs: List[Run] = Field(default_factory=list)

View File

@@ -6,7 +6,6 @@ from typing import Any, Iterable, List, Optional
import sqlalchemy
from sqlalchemy import (
CursorResult,
MetaData,
Table,
create_engine,
@@ -14,7 +13,7 @@ from sqlalchemy import (
select,
text,
)
from sqlalchemy.engine import Engine
from sqlalchemy.engine import CursorResult, Engine
from sqlalchemy.exc import ProgrammingError, SQLAlchemyError
from sqlalchemy.schema import CreateTable

View File

@@ -146,7 +146,8 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
@property
def is_single_input(self) -> bool:
"""Whether the tool only accepts a single input."""
return len(self.args) == 1
keys = {k for k in self.args if k != "kwargs"}
return len(keys) == 1
@property
def args(self) -> dict:

View File

@@ -36,7 +36,6 @@ class PythonREPLTool(BaseTool):
self,
query: str,
run_manager: Optional[CallbackManagerForToolRun] = None,
**kwargs: Any,
) -> Any:
"""Use the tool."""
if self.sanitize_input:
@@ -47,7 +46,6 @@ class PythonREPLTool(BaseTool):
self,
query: str,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
**kwargs: Any,
) -> Any:
"""Use the tool asynchronously."""
raise NotImplementedError("PythonReplTool does not support async")

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
import enum
import logging
import os
import uuid
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type
@@ -13,13 +14,13 @@ from sqlalchemy.orm import Session, declarative_base, relationship
from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.utils import get_from_dict_or_env
from langchain.utils import get_from_dict_or_env, get_from_env
from langchain.vectorstores.base import VectorStore
Base = declarative_base() # type: Any
ADA_TOKEN_COUNT = 1536
PGVECTOR_VECTOR_SIZE = 1536
_LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain"
@@ -79,13 +80,28 @@ class EmbeddingStore(BaseModel):
)
collection = relationship(CollectionStore, back_populates="embeddings")
embedding: Vector = sqlalchemy.Column(Vector(ADA_TOKEN_COUNT))
embedding: Vector = sqlalchemy.Column(Vector(PGVECTOR_VECTOR_SIZE))
document = sqlalchemy.Column(sqlalchemy.String, nullable=True)
cmetadata = sqlalchemy.Column(JSON, nullable=True)
# custom_id : any user defined id
custom_id = sqlalchemy.Column(sqlalchemy.String, nullable=True)
def __init__(
self, *args: Any, vector_size: Optional[int] = None, **kwargs: Any
) -> None:
if "embedding" not in kwargs:
if vector_size is None:
vector_size = int(
get_from_env(
"vector_size",
"PGVECTOR_VECTOR_SIZE",
default=str(PGVECTOR_VECTOR_SIZE),
)
)
kwargs["embedding"] = sqlalchemy.Column(Vector(vector_size))
super().__init__(*args, **kwargs)
class QueryResult:
EmbeddingStore: EmbeddingStore

View File

@@ -13,7 +13,7 @@ langchain-server = "langchain.server:main"
[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
pydantic = "^1"
SQLAlchemy = ">=1.3,<3"
SQLAlchemy = ">=1.4,<3"
requests = "^2"
PyYAML = ">=5.4.1"
numpy = "^1"

View File

@@ -1,11 +1,8 @@
"""Test tool utils."""
from datetime import datetime
from functools import partial
from typing import Any, Optional, Type, Union
from typing import Any, Type
from unittest.mock import MagicMock
import pytest
from pydantic import BaseModel
from langchain.agents.agent import Agent
from langchain.agents.chat.base import ChatAgent
@@ -15,383 +12,6 @@ from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.react.base import ReActDocstoreAgent, ReActTextWorldAgent
from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent
from langchain.agents.tools import Tool, tool
from langchain.tools.base import BaseTool, SchemaAnnotationError, StructuredTool
def test_unnamed_decorator() -> None:
"""Test functionality with unnamed decorator."""
@tool
def search_api(query: str) -> str:
"""Search the API for the query."""
return "API result"
assert isinstance(search_api, BaseTool)
assert search_api.name == "search_api"
assert not search_api.return_direct
assert search_api("test") == "API result"
class _MockSchema(BaseModel):
arg1: int
arg2: bool
arg3: Optional[dict] = None
class _MockStructuredTool(BaseTool):
name = "structured_api"
args_schema: Type[BaseModel] = _MockSchema
description = "A Structured Tool"
def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
return f"{arg1} {arg2} {arg3}"
async def _arun(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
raise NotImplementedError
def test_structured_args() -> None:
"""Test functionality with structured arguments."""
structured_api = _MockStructuredTool()
assert isinstance(structured_api, BaseTool)
assert structured_api.name == "structured_api"
expected_result = "1 True {'foo': 'bar'}"
args = {"arg1": 1, "arg2": True, "arg3": {"foo": "bar"}}
assert structured_api.run(args) == expected_result
def test_unannotated_base_tool_raises_error() -> None:
"""Test that a BaseTool without type hints raises an exception.""" ""
with pytest.raises(SchemaAnnotationError):
class _UnAnnotatedTool(BaseTool):
name = "structured_api"
# This would silently be ignored without the custom metaclass
args_schema = _MockSchema
description = "A Structured Tool"
def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
return f"{arg1} {arg2} {arg3}"
async def _arun(
self, arg1: int, arg2: bool, arg3: Optional[dict] = None
) -> str:
raise NotImplementedError
def test_misannotated_base_tool_raises_error() -> None:
"""Test that a BaseTool with the incorrrect typehint raises an exception.""" ""
with pytest.raises(SchemaAnnotationError):
class _MisAnnotatedTool(BaseTool):
name = "structured_api"
# This would silently be ignored without the custom metaclass
args_schema: BaseModel = _MockSchema # type: ignore
description = "A Structured Tool"
def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
return f"{arg1} {arg2} {arg3}"
async def _arun(
self, arg1: int, arg2: bool, arg3: Optional[dict] = None
) -> str:
raise NotImplementedError
def test_forward_ref_annotated_base_tool_accepted() -> None:
"""Test that a using forward ref annotation syntax is accepted.""" ""
class _ForwardRefAnnotatedTool(BaseTool):
name = "structured_api"
args_schema: "Type[BaseModel]" = _MockSchema
description = "A Structured Tool"
def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
return f"{arg1} {arg2} {arg3}"
async def _arun(
self, arg1: int, arg2: bool, arg3: Optional[dict] = None
) -> str:
raise NotImplementedError
def test_subclass_annotated_base_tool_accepted() -> None:
"""Test BaseTool child w/ custom schema isn't overwritten."""
class _ForwardRefAnnotatedTool(BaseTool):
name = "structured_api"
args_schema: Type[_MockSchema] = _MockSchema
description = "A Structured Tool"
def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
return f"{arg1} {arg2} {arg3}"
async def _arun(
self, arg1: int, arg2: bool, arg3: Optional[dict] = None
) -> str:
raise NotImplementedError
assert issubclass(_ForwardRefAnnotatedTool, BaseTool)
tool = _ForwardRefAnnotatedTool()
assert tool.args_schema == _MockSchema
def test_decorator_with_specified_schema() -> None:
"""Test that manually specified schemata are passed through to the tool."""
@tool(args_schema=_MockSchema)
def tool_func(arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
"""Return the arguments directly."""
return f"{arg1} {arg2} {arg3}"
assert isinstance(tool_func, BaseTool)
assert tool_func.args_schema == _MockSchema
def test_decorated_function_schema_equivalent() -> None:
"""Test that a BaseTool without a schema meets expectations."""
@tool
def structured_tool_input(
arg1: int, arg2: bool, arg3: Optional[dict] = None
) -> str:
"""Return the arguments directly."""
return f"{arg1} {arg2} {arg3}"
assert isinstance(structured_tool_input, BaseTool)
assert structured_tool_input.args_schema is not None
assert (
structured_tool_input.args_schema.schema()["properties"]
== _MockSchema.schema()["properties"]
== structured_tool_input.args
)
def test_structured_args_decorator_no_infer_schema() -> None:
"""Test functionality with structured arguments parsed as a decorator."""
@tool(infer_schema=False)
def structured_tool_input(
arg1: int, arg2: Union[float, datetime], opt_arg: Optional[dict] = None
) -> str:
"""Return the arguments directly."""
return f"{arg1}, {arg2}, {opt_arg}"
assert isinstance(structured_tool_input, BaseTool)
assert structured_tool_input.name == "structured_tool_input"
args = {"arg1": 1, "arg2": 0.001, "opt_arg": {"foo": "bar"}}
expected_result = "1, 0.001, {'foo': 'bar'}"
with pytest.raises(ValueError):
assert structured_tool_input.run(args) == expected_result
def test_structured_single_str_decorator_no_infer_schema() -> None:
"""Test functionality with structured arguments parsed as a decorator."""
@tool(infer_schema=False)
def unstructured_tool_input(tool_input: str) -> str:
"""Return the arguments directly."""
return f"{tool_input}"
assert isinstance(unstructured_tool_input, BaseTool)
assert unstructured_tool_input.args_schema is None
assert unstructured_tool_input.run("foo") == "foo"
def test_base_tool_inheritance_base_schema() -> None:
"""Test schema is correctly inferred when inheriting from BaseTool."""
class _MockSimpleTool(BaseTool):
name = "simple_tool"
description = "A Simple Tool"
def _run(self, tool_input: str) -> str:
return f"{tool_input}"
async def _arun(self, tool_input: str) -> str:
raise NotImplementedError
simple_tool = _MockSimpleTool()
assert simple_tool.args_schema is None
expected_args = {"tool_input": {"title": "Tool Input", "type": "string"}}
assert simple_tool.args == expected_args
def test_tool_lambda_args_schema() -> None:
"""Test args schema inference when the tool argument is a lambda function."""
tool = Tool(
name="tool",
description="A tool",
func=lambda tool_input: tool_input,
)
assert tool.args_schema is None
expected_args = {"tool_input": {"type": "string"}}
assert tool.args == expected_args
def test_structured_tool_lambda_multi_args_schema() -> None:
"""Test args schema inference when the tool argument is a lambda function."""
tool = StructuredTool.from_function(
name="tool",
description="A tool",
func=lambda tool_input, other_arg: f"{tool_input}{other_arg}", # type: ignore
)
assert tool.args_schema is not None
expected_args = {
"tool_input": {"title": "Tool Input"},
"other_arg": {"title": "Other Arg"},
}
assert tool.args == expected_args
def test_tool_partial_function_args_schema() -> None:
"""Test args schema inference when the tool argument is a partial function."""
def func(tool_input: str, other_arg: str) -> str:
return tool_input + other_arg
tool = Tool(
name="tool",
description="A tool",
func=partial(func, other_arg="foo"),
)
assert tool.run("bar") == "barfoo"
def test_empty_args_decorator() -> None:
"""Test inferred schema of decorated fn with no args."""
@tool
def empty_tool_input() -> str:
"""Return a constant."""
return "the empty result"
assert isinstance(empty_tool_input, BaseTool)
assert empty_tool_input.name == "empty_tool_input"
assert empty_tool_input.args == {}
assert empty_tool_input.run({}) == "the empty result"
def test_named_tool_decorator() -> None:
"""Test functionality when arguments are provided as input to decorator."""
@tool("search")
def search_api(query: str) -> str:
"""Search the API for the query."""
return "API result"
assert isinstance(search_api, BaseTool)
assert search_api.name == "search"
assert not search_api.return_direct
def test_named_tool_decorator_return_direct() -> None:
"""Test functionality when arguments and return direct are provided as input."""
@tool("search", return_direct=True)
def search_api(query: str) -> str:
"""Search the API for the query."""
return "API result"
assert isinstance(search_api, BaseTool)
assert search_api.name == "search"
assert search_api.return_direct
def test_unnamed_tool_decorator_return_direct() -> None:
"""Test functionality when only return direct is provided."""
@tool(return_direct=True)
def search_api(query: str) -> str:
"""Search the API for the query."""
return "API result"
assert isinstance(search_api, BaseTool)
assert search_api.name == "search_api"
assert search_api.return_direct
def test_tool_with_kwargs() -> None:
"""Test functionality when only return direct is provided."""
@tool(return_direct=True)
def search_api(
arg_0: str,
arg_1: float = 4.3,
ping: str = "hi",
) -> str:
"""Search the API for the query."""
return f"arg_0={arg_0}, arg_1={arg_1}, ping={ping}"
assert isinstance(search_api, BaseTool)
result = search_api.run(
tool_input={
"arg_0": "foo",
"arg_1": 3.2,
"ping": "pong",
}
)
assert result == "arg_0=foo, arg_1=3.2, ping=pong"
result = search_api.run(
tool_input={
"arg_0": "foo",
}
)
assert result == "arg_0=foo, arg_1=4.3, ping=hi"
# For backwards compatibility, we still accept a single str arg
result = search_api.run("foobar")
assert result == "arg_0=foobar, arg_1=4.3, ping=hi"
def test_missing_docstring() -> None:
"""Test error is raised when docstring is missing."""
# expect to throw a value error if theres no docstring
with pytest.raises(AssertionError, match="Function must have a docstring"):
@tool
def search_api(query: str) -> str:
return "API result"
def test_create_tool_positional_args() -> None:
"""Test that positional arguments are allowed."""
test_tool = Tool("test_name", lambda x: x, "test_description")
assert test_tool("foo") == "foo"
assert test_tool.name == "test_name"
assert test_tool.description == "test_description"
assert test_tool.is_single_input
def test_create_tool_keyword_args() -> None:
"""Test that keyword arguments are allowed."""
test_tool = Tool(name="test_name", func=lambda x: x, description="test_description")
assert test_tool.is_single_input
assert test_tool("foo") == "foo"
assert test_tool.name == "test_name"
assert test_tool.description == "test_description"
@pytest.mark.asyncio
async def test_create_async_tool() -> None:
"""Test that async tools are allowed."""
async def _test_func(x: str) -> str:
return x
test_tool = Tool(
name="test_name",
func=lambda x: x,
description="test_description",
coroutine=_test_func,
)
assert test_tool.is_single_input
assert test_tool("foo") == "foo"
assert test_tool.name == "test_name"
assert test_tool.description == "test_description"
assert test_tool.coroutine is not None
assert await test_tool.arun("foo") == "foo"
@pytest.mark.parametrize(

View File

@@ -2,8 +2,9 @@
from __future__ import annotations
from datetime import datetime
from typing import List, Union
from uuid import uuid4
from typing import List, Tuple, Union
from unittest.mock import Mock, patch
from uuid import UUID, uuid4
import pytest
from freezegun import freeze_time
@@ -16,7 +17,8 @@ from langchain.callbacks.tracers.base import (
TracerException,
TracerSession,
)
from langchain.callbacks.tracers.schemas import TracerSessionCreate
from langchain.callbacks.tracers.langchain import LangChainTracerV2
from langchain.callbacks.tracers.schemas import Run, TracerSessionBase, TracerSessionV2
from langchain.schema import LLMResult
TEST_SESSION_ID = 2023
@@ -27,7 +29,7 @@ def load_session(session_name: str) -> TracerSession:
return TracerSession(id=1, name=session_name, start_time=datetime.utcnow())
def _persist_session(session: TracerSessionCreate) -> TracerSession:
def _persist_session(session: TracerSessionBase) -> TracerSession:
"""Persist a tracing session."""
return TracerSession(id=TEST_SESSION_ID, **session.dict())
@@ -49,7 +51,7 @@ class FakeTracer(BaseTracer):
"""Persist a run."""
self.runs.append(run)
def _persist_session(self, session: TracerSessionCreate) -> TracerSession:
def _persist_session(self, session: TracerSessionBase) -> TracerSession:
"""Persist a tracing session."""
return _persist_session(session)
@@ -473,3 +475,125 @@ def test_tracer_nested_runs_on_error() -> None:
)
assert tracer.runs == [compare_run] * 3
_SESSION_ID = UUID("4fbf7c55-2727-4711-8964-d821ed4d4e2a")
_TENANT_ID = UUID("57a08cc4-73d2-4236-8378-549099d07fad")
@pytest.fixture
def lang_chain_tracer_v2(monkeypatch: pytest.MonkeyPatch) -> LangChainTracerV2:
monkeypatch.setenv("LANGCHAIN_TENANT_ID", "test-tenant-id")
monkeypatch.setenv("LANGCHAIN_ENDPOINT", "http://test-endpoint.com")
monkeypatch.setenv("LANGCHAIN_API_KEY", "foo")
tracer = LangChainTracerV2()
return tracer
# Mock a sample TracerSessionV2 object
@pytest.fixture
def sample_tracer_session_v2() -> TracerSessionV2:
return TracerSessionV2(id=_SESSION_ID, name="Sample session", tenant_id=_TENANT_ID)
# Mock a sample LLMRun, ChainRun, and ToolRun objects
@pytest.fixture
def sample_runs() -> Tuple[LLMRun, ChainRun, ToolRun]:
llm_run = LLMRun(
uuid="57a08cc4-73d2-4236-8370-549099d07fad",
name="llm_run",
execution_order=1,
child_execution_order=1,
session_id=1,
prompts=[],
response=LLMResult(generations=[[]]),
serialized={},
extra={},
)
chain_run = ChainRun(
uuid="57a08cc4-73d2-4236-8371-549099d07fad",
name="chain_run",
execution_order=1,
child_execution_order=1,
session_id=1,
serialized={},
inputs={},
outputs=None,
child_llm_runs=[llm_run],
child_chain_runs=[],
child_tool_runs=[],
extra={},
)
tool_run = ToolRun(
uuid="57a08cc4-73d2-4236-8372-549099d07fad",
name="tool_run",
execution_order=1,
child_execution_order=1,
session_id=1,
tool_input="test",
action="{}",
serialized={},
child_llm_runs=[],
child_chain_runs=[],
child_tool_runs=[],
extra={},
)
return llm_run, chain_run, tool_run
# Test _get_default_query_params method
def test_get_default_query_params(lang_chain_tracer_v2: LangChainTracerV2) -> None:
expected = {"tenant_id": "test-tenant-id"}
result = lang_chain_tracer_v2._get_default_query_params()
assert result == expected
# Test load_session method
@patch("langchain.callbacks.tracers.langchain.requests.get")
def test_load_session(
mock_requests_get: Mock,
lang_chain_tracer_v2: LangChainTracerV2,
sample_tracer_session_v2: TracerSessionV2,
) -> None:
"""Test that load_session method returns a TracerSessionV2 object."""
mock_requests_get.return_value.json.return_value = [sample_tracer_session_v2.dict()]
result = lang_chain_tracer_v2.load_session("test-session-name")
mock_requests_get.assert_called_with(
"http://test-endpoint.com/sessions",
headers={"Content-Type": "application/json", "x-api-key": "foo"},
params={"tenant_id": "test-tenant-id", "name": "test-session-name"},
)
assert result == sample_tracer_session_v2
def test_convert_run(
lang_chain_tracer_v2: LangChainTracerV2,
sample_tracer_session_v2: TracerSessionV2,
sample_runs: Tuple[LLMRun, ChainRun, ToolRun],
) -> None:
llm_run, chain_run, tool_run = sample_runs
lang_chain_tracer_v2.session = sample_tracer_session_v2
converted_llm_run = lang_chain_tracer_v2._convert_run(llm_run)
converted_chain_run = lang_chain_tracer_v2._convert_run(chain_run)
converted_tool_run = lang_chain_tracer_v2._convert_run(tool_run)
assert isinstance(converted_llm_run, Run)
assert isinstance(converted_chain_run, Run)
assert isinstance(converted_tool_run, Run)
@patch("langchain.callbacks.tracers.langchain.requests.post")
def test_persist_run(
mock_requests_post: Mock,
lang_chain_tracer_v2: LangChainTracerV2,
sample_tracer_session_v2: TracerSessionV2,
sample_runs: Tuple[LLMRun, ChainRun, ToolRun],
) -> None:
mock_requests_post.return_value.raise_for_status.return_value = None
lang_chain_tracer_v2.session = sample_tracer_session_v2
llm_run, chain_run, tool_run = sample_runs
lang_chain_tracer_v2._persist_run(llm_run)
lang_chain_tracer_v2._persist_run(chain_run)
lang_chain_tracer_v2._persist_run(tool_run)
assert mock_requests_post.call_count == 3

View File

@@ -0,0 +1,23 @@
"""Test Python REPL Tools."""
import sys
import pytest
from langchain.tools.python.tool import PythonAstREPLTool, PythonREPLTool
def test_python_repl_tool_single_input() -> None:
"""Test that the python REPL tool works with a single input."""
tool = PythonREPLTool()
assert tool.is_single_input
assert int(tool.run("print(1 + 1)").strip()) == 2
@pytest.mark.skipif(
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
)
def test_python_ast_repl_tool_single_input() -> None:
"""Test that the python REPL tool works with a single input."""
tool = PythonAstREPLTool()
assert tool.is_single_input
assert tool.run("1 + 1") == 2

View File

@@ -0,0 +1,438 @@
"""Test the base tool implementation."""
from datetime import datetime
from functools import partial
from typing import Any, Optional, Type, Union
import pytest
from pydantic import BaseModel
from langchain.agents.tools import Tool, tool
from langchain.callbacks.manager import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain.tools.base import BaseTool, SchemaAnnotationError, StructuredTool
def test_unnamed_decorator() -> None:
"""Test functionality with unnamed decorator."""
@tool
def search_api(query: str) -> str:
"""Search the API for the query."""
return "API result"
assert isinstance(search_api, BaseTool)
assert search_api.name == "search_api"
assert not search_api.return_direct
assert search_api("test") == "API result"
class _MockSchema(BaseModel):
arg1: int
arg2: bool
arg3: Optional[dict] = None
class _MockStructuredTool(BaseTool):
name = "structured_api"
args_schema: Type[BaseModel] = _MockSchema
description = "A Structured Tool"
def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
return f"{arg1} {arg2} {arg3}"
async def _arun(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
raise NotImplementedError
def test_structured_args() -> None:
"""Test functionality with structured arguments."""
structured_api = _MockStructuredTool()
assert isinstance(structured_api, BaseTool)
assert structured_api.name == "structured_api"
expected_result = "1 True {'foo': 'bar'}"
args = {"arg1": 1, "arg2": True, "arg3": {"foo": "bar"}}
assert structured_api.run(args) == expected_result
def test_unannotated_base_tool_raises_error() -> None:
"""Test that a BaseTool without type hints raises an exception.""" ""
with pytest.raises(SchemaAnnotationError):
class _UnAnnotatedTool(BaseTool):
name = "structured_api"
# This would silently be ignored without the custom metaclass
args_schema = _MockSchema
description = "A Structured Tool"
def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
return f"{arg1} {arg2} {arg3}"
async def _arun(
self, arg1: int, arg2: bool, arg3: Optional[dict] = None
) -> str:
raise NotImplementedError
def test_misannotated_base_tool_raises_error() -> None:
"""Test that a BaseTool with the incorrrect typehint raises an exception.""" ""
with pytest.raises(SchemaAnnotationError):
class _MisAnnotatedTool(BaseTool):
name = "structured_api"
# This would silently be ignored without the custom metaclass
args_schema: BaseModel = _MockSchema # type: ignore
description = "A Structured Tool"
def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
return f"{arg1} {arg2} {arg3}"
async def _arun(
self, arg1: int, arg2: bool, arg3: Optional[dict] = None
) -> str:
raise NotImplementedError
def test_forward_ref_annotated_base_tool_accepted() -> None:
"""Test that a using forward ref annotation syntax is accepted.""" ""
class _ForwardRefAnnotatedTool(BaseTool):
name = "structured_api"
args_schema: "Type[BaseModel]" = _MockSchema
description = "A Structured Tool"
def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
return f"{arg1} {arg2} {arg3}"
async def _arun(
self, arg1: int, arg2: bool, arg3: Optional[dict] = None
) -> str:
raise NotImplementedError
def test_subclass_annotated_base_tool_accepted() -> None:
"""Test BaseTool child w/ custom schema isn't overwritten."""
class _ForwardRefAnnotatedTool(BaseTool):
name = "structured_api"
args_schema: Type[_MockSchema] = _MockSchema
description = "A Structured Tool"
def _run(self, arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
return f"{arg1} {arg2} {arg3}"
async def _arun(
self, arg1: int, arg2: bool, arg3: Optional[dict] = None
) -> str:
raise NotImplementedError
assert issubclass(_ForwardRefAnnotatedTool, BaseTool)
tool = _ForwardRefAnnotatedTool()
assert tool.args_schema == _MockSchema
def test_decorator_with_specified_schema() -> None:
"""Test that manually specified schemata are passed through to the tool."""
@tool(args_schema=_MockSchema)
def tool_func(arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
"""Return the arguments directly."""
return f"{arg1} {arg2} {arg3}"
assert isinstance(tool_func, BaseTool)
assert tool_func.args_schema == _MockSchema
def test_decorated_function_schema_equivalent() -> None:
"""Test that a BaseTool without a schema meets expectations."""
@tool
def structured_tool_input(
arg1: int, arg2: bool, arg3: Optional[dict] = None
) -> str:
"""Return the arguments directly."""
return f"{arg1} {arg2} {arg3}"
assert isinstance(structured_tool_input, BaseTool)
assert structured_tool_input.args_schema is not None
assert (
structured_tool_input.args_schema.schema()["properties"]
== _MockSchema.schema()["properties"]
== structured_tool_input.args
)
def test_args_kwargs_filtered() -> None:
class _SingleArgToolWithKwargs(BaseTool):
name = "single_arg_tool"
description = "A single arged tool with kwargs"
def _run(
self,
some_arg: str,
run_manager: Optional[CallbackManagerForToolRun] = None,
**kwargs: Any,
) -> str:
return "foo"
async def _arun(
self,
some_arg: str,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
**kwargs: Any,
) -> str:
raise NotImplementedError
tool = _SingleArgToolWithKwargs()
assert tool.is_single_input
class _VarArgToolWithKwargs(BaseTool):
name = "single_arg_tool"
description = "A single arged tool with kwargs"
def _run(
self,
*args: Any,
run_manager: Optional[CallbackManagerForToolRun] = None,
**kwargs: Any,
) -> str:
return "foo"
async def _arun(
self,
*args: Any,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
**kwargs: Any,
) -> str:
raise NotImplementedError
tool2 = _VarArgToolWithKwargs()
assert tool2.is_single_input
def test_structured_args_decorator_no_infer_schema() -> None:
"""Test functionality with structured arguments parsed as a decorator."""
@tool(infer_schema=False)
def structured_tool_input(
arg1: int, arg2: Union[float, datetime], opt_arg: Optional[dict] = None
) -> str:
"""Return the arguments directly."""
return f"{arg1}, {arg2}, {opt_arg}"
assert isinstance(structured_tool_input, BaseTool)
assert structured_tool_input.name == "structured_tool_input"
args = {"arg1": 1, "arg2": 0.001, "opt_arg": {"foo": "bar"}}
expected_result = "1, 0.001, {'foo': 'bar'}"
with pytest.raises(ValueError):
assert structured_tool_input.run(args) == expected_result
def test_structured_single_str_decorator_no_infer_schema() -> None:
"""Test functionality with structured arguments parsed as a decorator."""
@tool(infer_schema=False)
def unstructured_tool_input(tool_input: str) -> str:
"""Return the arguments directly."""
return f"{tool_input}"
assert isinstance(unstructured_tool_input, BaseTool)
assert unstructured_tool_input.args_schema is None
assert unstructured_tool_input.run("foo") == "foo"
def test_base_tool_inheritance_base_schema() -> None:
"""Test schema is correctly inferred when inheriting from BaseTool."""
class _MockSimpleTool(BaseTool):
name = "simple_tool"
description = "A Simple Tool"
def _run(self, tool_input: str) -> str:
return f"{tool_input}"
async def _arun(self, tool_input: str) -> str:
raise NotImplementedError
simple_tool = _MockSimpleTool()
assert simple_tool.args_schema is None
expected_args = {"tool_input": {"title": "Tool Input", "type": "string"}}
assert simple_tool.args == expected_args
def test_tool_lambda_args_schema() -> None:
"""Test args schema inference when the tool argument is a lambda function."""
tool = Tool(
name="tool",
description="A tool",
func=lambda tool_input: tool_input,
)
assert tool.args_schema is None
expected_args = {"tool_input": {"type": "string"}}
assert tool.args == expected_args
def test_structured_tool_lambda_multi_args_schema() -> None:
"""Test args schema inference when the tool argument is a lambda function."""
tool = StructuredTool.from_function(
name="tool",
description="A tool",
func=lambda tool_input, other_arg: f"{tool_input}{other_arg}", # type: ignore
)
assert tool.args_schema is not None
expected_args = {
"tool_input": {"title": "Tool Input"},
"other_arg": {"title": "Other Arg"},
}
assert tool.args == expected_args
def test_tool_partial_function_args_schema() -> None:
"""Test args schema inference when the tool argument is a partial function."""
def func(tool_input: str, other_arg: str) -> str:
return tool_input + other_arg
tool = Tool(
name="tool",
description="A tool",
func=partial(func, other_arg="foo"),
)
assert tool.run("bar") == "barfoo"
def test_empty_args_decorator() -> None:
"""Test inferred schema of decorated fn with no args."""
@tool
def empty_tool_input() -> str:
"""Return a constant."""
return "the empty result"
assert isinstance(empty_tool_input, BaseTool)
assert empty_tool_input.name == "empty_tool_input"
assert empty_tool_input.args == {}
assert empty_tool_input.run({}) == "the empty result"
def test_named_tool_decorator() -> None:
"""Test functionality when arguments are provided as input to decorator."""
@tool("search")
def search_api(query: str) -> str:
"""Search the API for the query."""
return "API result"
assert isinstance(search_api, BaseTool)
assert search_api.name == "search"
assert not search_api.return_direct
def test_named_tool_decorator_return_direct() -> None:
"""Test functionality when arguments and return direct are provided as input."""
@tool("search", return_direct=True)
def search_api(query: str) -> str:
"""Search the API for the query."""
return "API result"
assert isinstance(search_api, BaseTool)
assert search_api.name == "search"
assert search_api.return_direct
def test_unnamed_tool_decorator_return_direct() -> None:
"""Test functionality when only return direct is provided."""
@tool(return_direct=True)
def search_api(query: str) -> str:
"""Search the API for the query."""
return "API result"
assert isinstance(search_api, BaseTool)
assert search_api.name == "search_api"
assert search_api.return_direct
def test_tool_with_kwargs() -> None:
"""Test functionality when only return direct is provided."""
@tool(return_direct=True)
def search_api(
arg_0: str,
arg_1: float = 4.3,
ping: str = "hi",
) -> str:
"""Search the API for the query."""
return f"arg_0={arg_0}, arg_1={arg_1}, ping={ping}"
assert isinstance(search_api, BaseTool)
result = search_api.run(
tool_input={
"arg_0": "foo",
"arg_1": 3.2,
"ping": "pong",
}
)
assert result == "arg_0=foo, arg_1=3.2, ping=pong"
result = search_api.run(
tool_input={
"arg_0": "foo",
}
)
assert result == "arg_0=foo, arg_1=4.3, ping=hi"
# For backwards compatibility, we still accept a single str arg
result = search_api.run("foobar")
assert result == "arg_0=foobar, arg_1=4.3, ping=hi"
def test_missing_docstring() -> None:
"""Test error is raised when docstring is missing."""
# expect to throw a value error if theres no docstring
with pytest.raises(AssertionError, match="Function must have a docstring"):
@tool
def search_api(query: str) -> str:
return "API result"
def test_create_tool_positional_args() -> None:
"""Test that positional arguments are allowed."""
test_tool = Tool("test_name", lambda x: x, "test_description")
assert test_tool("foo") == "foo"
assert test_tool.name == "test_name"
assert test_tool.description == "test_description"
assert test_tool.is_single_input
def test_create_tool_keyword_args() -> None:
"""Test that keyword arguments are allowed."""
test_tool = Tool(name="test_name", func=lambda x: x, description="test_description")
assert test_tool.is_single_input
assert test_tool("foo") == "foo"
assert test_tool.name == "test_name"
assert test_tool.description == "test_description"
@pytest.mark.asyncio
async def test_create_async_tool() -> None:
"""Test that async tools are allowed."""
async def _test_func(x: str) -> str:
return x
test_tool = Tool(
name="test_name",
func=lambda x: x,
description="test_description",
coroutine=_test_func,
)
assert test_tool.is_single_input
assert test_tool("foo") == "foo"
assert test_tool.name == "test_name"
assert test_tool.description == "test_description"
assert test_tool.coroutine is not None
assert await test_tool.arun("foo") == "foo"

View File

@@ -70,7 +70,7 @@ def test_success(mocked_responses: responses.RequestsMock, ref: str) -> None:
assert file_contents is None
file_contents = Path(file_path).read_text()
mocked_responses.get(
mocked_responses.get( # type: ignore
urljoin(URL_BASE.format(ref=ref), path),
body=body,
status=200,
@@ -86,7 +86,9 @@ def test_failed_request(mocked_responses: responses.RequestsMock) -> None:
path = "chains/path/chain.json"
loader = Mock()
mocked_responses.get(urljoin(URL_BASE.format(ref=DEFAULT_REF), path), status=500)
mocked_responses.get( # type: ignore
urljoin(URL_BASE.format(ref=DEFAULT_REF), path), status=500
)
with pytest.raises(ValueError, match=re.compile("Could not find file at .*")):
try_load_from_hub(f"lc://{path}", loader, "chains", {"json"})

View File

@@ -0,0 +1,22 @@
import os
from langchain.vectorstores.pgvector import PGVECTOR_VECTOR_SIZE, EmbeddingStore
def test_embedding_store_init_defaults() -> None:
expected = PGVECTOR_VECTOR_SIZE
actual = EmbeddingStore().embedding.type.dim
assert expected == actual
def test_embedding_store_init_vector_size() -> None:
expected = 2
actual = EmbeddingStore(vector_size=2).embedding.type.dim
assert expected == actual
def test_embedding_store_init_env_vector_size() -> None:
os.environ["PGVECTOR_VECTOR_SIZE"] = "3"
expected = 3
actual = EmbeddingStore().embedding.type.dim
assert expected == actual