feat(langchain): v1 scaffolding (#32166)

This PR adds scaffolding for the langchain 1.0 entry package.

Most of the existing content has been removed.

Entrypoints currently remain for the following (a usage sketch follows the list):

* chat models
* embedding models
* memory -> trimming messages, filtering messages, and counting tokens (we may remove this)
* prompts -> we may remove some prompts
* storage -> primarily to support cache-backed embeddings; we may remove the KV store
* tools -> re-export tool primitives
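
A rough sketch of the intended top-level surface (assuming the entrypoints land as `langchain.chat_models`, `langchain.embeddings`, and `langchain.memory`; the exact API may still change):

```python
from langchain_core.messages import HumanMessage

from langchain.chat_models import init_chat_model
from langchain.embeddings import init_embeddings
from langchain.memory import count_tokens_approximately, trim_messages

model = init_chat_model("openai:gpt-4o", temperature=0)      # requires langchain-openai
embedder = init_embeddings("openai:text-embedding-3-small")  # requires langchain-openai

history = [HumanMessage("Summarize the latest report.")]
trimmed = trim_messages(
    history,
    max_tokens=1_000,
    token_counter=count_tokens_approximately,
)
response = model.invoke(trimmed)
```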

Things to be added:

* Selected agent implementations
* Selected workflows
* Common primitives: messages, Document
* Primitives for type hinting: BaseChatModel, BaseEmbeddings
* Selected retrievers
* Selected text splitters

Things to be removed:

* Globals need to be removed (requires an update in langchain-core)


Todos: 

* TBD: indexing API (requires SQLAlchemy, which we don't want as a dependency)
* Be explicit about public/private interfaces (e.g., likely rename
chat_models.base.py to something more internal)
* Remove Dockerfiles
* Update module docstrings and README.md
Author: Eugene Yurtsev
Date: 2025-07-24 09:47:48 -04:00
Committed by: GitHub
Parent: bd3d6496f3
Commit: 56dde3ade3
57 changed files with 8548 additions and 0 deletions

View File

@@ -16,6 +16,7 @@ LANGCHAIN_DIRS = [
"libs/core",
"libs/text-splitters",
"libs/langchain",
"libs/langchain_v1",
]
# when set to True, we are ignoring core dependents

View File

@@ -0,0 +1,6 @@
.venv
.github
.git
.mypy_cache
.pytest_cache
Dockerfile

21
libs/langchain_v1/LICENSE Normal file
View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) LangChain, Inc.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

101
libs/langchain_v1/Makefile Normal file
View File

@@ -0,0 +1,101 @@
.PHONY: all clean docs_build docs_clean docs_linkcheck api_docs_build api_docs_clean api_docs_linkcheck format lint test tests test_watch integration_tests docker_tests help extended_tests
# Default target executed when no arguments are given to make.
all: help
######################
# TESTING AND COVERAGE
######################
# Define a variable for the test file path.
TEST_FILE ?= tests/unit_tests/
.EXPORT_ALL_VARIABLES:
UV_FROZEN = true
# Run unit tests and generate a coverage report.
coverage:
uv run --group test pytest --cov \
--cov-config=.coveragerc \
--cov-report xml \
--cov-report term-missing:skip-covered \
$(TEST_FILE)
test tests:
uv run --group test pytest -n auto --disable-socket --allow-unix-socket $(TEST_FILE)
extended_tests:
uv run --group test pytest --disable-socket --allow-unix-socket --only-extended tests/unit_tests
test_watch:
uv run --group test ptw --snapshot-update --now . -- -x --disable-socket --allow-unix-socket --disable-warnings tests/unit_tests
test_watch_extended:
uv run --group test ptw --snapshot-update --now . -- -x --disable-socket --allow-unix-socket --only-extended tests/unit_tests
integration_tests:
uv run --group test --group test_integration pytest tests/integration_tests
docker_tests:
docker build -t my-langchain-image:test .
docker run --rm my-langchain-image:test
check_imports: $(shell find langchain -name '*.py')
uv run python ./scripts/check_imports.py $^
######################
# LINTING AND FORMATTING
######################
# Define a variable for Python and notebook files.
PYTHON_FILES=.
MYPY_CACHE=.mypy_cache
lint format: PYTHON_FILES=.
lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/langchain_v1 --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
lint_package: PYTHON_FILES=langchain
lint_tests: PYTHON_FILES=tests
lint_tests: MYPY_CACHE=.mypy_cache_test
lint lint_diff lint_package lint_tests:
[ "$(PYTHON_FILES)" = "" ] || uv run --all-groups ruff check $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || uv run --all-groups ruff format $(PYTHON_FILES) --diff
[ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) && uv run --all-groups mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
format format_diff:
[ "$(PYTHON_FILES)" = "" ] || uv run --all-groups ruff format $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || uv run --all-groups ruff check --fix $(PYTHON_FILES)
spell_check:
uv run --all-groups codespell --toml pyproject.toml
spell_fix:
uv run --all-groups codespell --toml pyproject.toml -w
######################
# HELP
######################
help:
@echo '===================='
@echo 'clean - run docs_clean and api_docs_clean'
@echo 'docs_build - build the documentation'
@echo 'docs_clean - clean the documentation build artifacts'
@echo 'docs_linkcheck - run linkchecker on the documentation'
@echo 'api_docs_build - build the API Reference documentation'
@echo 'api_docs_clean - clean the API Reference documentation build artifacts'
@echo 'api_docs_linkcheck - run linkchecker on the API Reference documentation'
@echo '-- LINTING --'
@echo 'format - run code formatters'
@echo 'lint - run linters'
@echo 'spell_check - run codespell on the project'
@echo 'spell_fix - run codespell on the project and fix the errors'
@echo '-- TESTS --'
@echo 'coverage - run unit tests and generate coverage report'
@echo 'test - run unit tests'
@echo 'tests - run unit tests (alias for "make test")'
@echo 'test TEST_FILE=<test_file> - run all tests in file'
@echo 'extended_tests - run only extended unit tests'
@echo 'test_watch - run unit tests in watch mode'
@echo 'integration_tests - run integration tests'
@echo 'docker_tests - run unit tests in docker'
@echo '-- DOCUMENTATION tasks are from the top-level Makefile --'

View File

@@ -0,0 +1,91 @@
# 🦜️🔗 LangChain
⚡ Building applications with LLMs through composability ⚡
[![Release Notes](https://img.shields.io/github/release/langchain-ai/langchain)](https://github.com/langchain-ai/langchain/releases)
[![lint](https://github.com/langchain-ai/langchain/actions/workflows/lint.yml/badge.svg)](https://github.com/langchain-ai/langchain/actions/workflows/lint.yml)
[![test](https://github.com/langchain-ai/langchain/actions/workflows/test.yml/badge.svg)](https://github.com/langchain-ai/langchain/actions/workflows/test.yml)
[![Downloads](https://static.pepy.tech/badge/langchain/month)](https://pepy.tech/project/langchain)
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
[![Twitter](https://img.shields.io/twitter/url/https/twitter.com/langchainai.svg?style=social&label=Follow%20%40LangChainAI)](https://twitter.com/langchainai)
[![Open in Dev Containers](https://img.shields.io/static/v1?label=Dev%20Containers&message=Open&color=blue&logo=visualstudiocode)](https://vscode.dev/redirect?url=vscode://ms-vscode-remote.remote-containers/cloneInVolume?url=https://github.com/langchain-ai/langchain)
[![Open in GitHub Codespaces](https://github.com/codespaces/badge.svg)](https://codespaces.new/langchain-ai/langchain)
[![GitHub star chart](https://img.shields.io/github/stars/langchain-ai/langchain?style=social)](https://star-history.com/#langchain-ai/langchain)
[![Dependency Status](https://img.shields.io/librariesio/github/langchain-ai/langchain)](https://libraries.io/github/langchain-ai/langchain)
[![Open Issues](https://img.shields.io/github/issues-raw/langchain-ai/langchain)](https://github.com/langchain-ai/langchain/issues)
Looking for the JS/TS version? Check out [LangChain.js](https://github.com/langchain-ai/langchainjs).
To help you ship LangChain apps to production faster, check out [LangSmith](https://smith.langchain.com).
[LangSmith](https://smith.langchain.com) is a unified developer platform for building, testing, and monitoring LLM applications.
Fill out [this form](https://www.langchain.com/contact-sales) to speak with our sales team.
## Quick Install
`pip install langchain`
or
`pip install langsmith && conda install langchain -c conda-forge`
## 🤔 What is this?
Large language models (LLMs) are emerging as a transformative technology, enabling developers to build applications that they previously could not. However, using these LLMs in isolation is often insufficient for creating a truly powerful app - the real power comes when you can combine them with other sources of computation or knowledge.
This library aims to assist in the development of those types of applications. Common examples of these applications include:
**❓ Question answering with RAG**
- [Documentation](https://python.langchain.com/docs/use_cases/question_answering/)
- End-to-end Example: [Chat LangChain](https://chat.langchain.com) and [repo](https://github.com/langchain-ai/chat-langchain)
**🧱 Extracting structured output**
- [Documentation](https://python.langchain.com/docs/use_cases/extraction/)
- End-to-end Example: [LangChain Extract](https://github.com/langchain-ai/langchain-extract/)
**🤖 Chatbots**
- [Documentation](https://python.langchain.com/docs/use_cases/chatbots)
- End-to-end Example: [Web LangChain (web researcher chatbot)](https://weblangchain.vercel.app) and [repo](https://github.com/langchain-ai/weblangchain)
## 📖 Documentation
Please see [here](https://python.langchain.com) for full documentation on:
- Getting started (installation, setting up the environment, simple examples)
- How-To examples (demos, integrations, helper functions)
- Reference (full API docs)
- Resources (high-level explanation of core concepts)
## 🚀 What can this help with?
There are five main areas that LangChain is designed to help with.
These are, in increasing order of complexity:
**📃 Models and Prompts:**
This includes prompt management, prompt optimization, a generic interface for all LLMs, and common utilities for working with chat models and LLMs.
**🔗 Chains:**
Chains go beyond a single LLM call and involve sequences of calls (whether to an LLM or a different utility). LangChain provides a standard interface for chains, lots of integrations with other tools, and end-to-end chains for common applications.
**📚 Retrieval Augmented Generation:**
Retrieval Augmented Generation involves specific types of chains that first interact with an external data source to fetch data for use in the generation step. Examples include summarization of long pieces of text and question/answering over specific data sources.
**🤖 Agents:**
Agents involve an LLM making decisions about which Actions to take, taking that Action, seeing an Observation, and repeating that until done. LangChain provides a standard interface for agents, a selection of agents to choose from, and examples of end-to-end agents.
**🧐 Evaluation:**
[BETA] Generative models are notoriously hard to evaluate with traditional metrics. One new way of evaluating them is using language models themselves to do the evaluation. LangChain provides some prompts/chains for assisting in this.
For more information on these concepts, please see our [full documentation](https://python.langchain.com).
## 💁 Contributing
As an open-source project in a rapidly developing field, we are extremely open to contributions, whether it be in the form of a new feature, improved infrastructure, or better documentation.
For detailed information on how to contribute, see the [Contributing Guide](https://python.langchain.com/docs/contributing/).

View File

@@ -0,0 +1,5 @@
-e ../partners/openai
-e ../partners/anthropic
-e ../partners/fireworks
-e ../partners/mistralai
-e ../partners/groq

View File

@@ -0,0 +1,29 @@
"""Main entrypoint into package."""
from importlib import metadata
from typing import Any
try:
__version__ = metadata.version(__package__)
except metadata.PackageNotFoundError:
# Case where package metadata is not available.
__version__ = ""
del metadata # optional, avoids polluting the results of dir(__package__)
def __getattr__(name: str) -> Any: # noqa: ANN401
"""Get an attribute from the package."""
if name == "verbose":
from langchain.globals import _verbose
return _verbose
if name == "debug":
from langchain.globals import _debug
return _debug
if name == "llm_cache":
from langchain.globals import _llm_cache
return _llm_cache
msg = f"Could not find: {name}"
raise AttributeError(msg)
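
A minimal sketch of how the lazy lookup above is expected to behave from a caller's perspective (assuming the globals module keeps exposing _verbose, _debug, and _llm_cache):

import langchain

print(langchain.__version__)  # "" when package metadata is unavailable
print(langchain.verbose)      # resolved lazily from langchain.globals._verbose (False)
print(langchain.debug)        # resolved lazily from langchain.globals._debug (False)
print(langchain.llm_cache)    # resolved lazily from langchain.globals._llm_cache (None)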

View File

@@ -0,0 +1,24 @@
"""**Chat Models** are a variation on language models.
While Chat Models use language models under the hood, the interface they expose
is a bit different. Rather than expose a "text in, text out" API, they expose
an interface where "chat messages" are the inputs and outputs.
**Class hierarchy:**
.. code-block::
BaseLanguageModel --> BaseChatModel --> <name> # Examples: ChatOpenAI, ChatGooglePalm
**Main helpers:**
.. code-block::
AIMessage, BaseMessage, HumanMessage
""" # noqa: E501
from langchain.chat_models.base import init_chat_model
__all__ = [
"init_chat_model",
]

View File

@@ -0,0 +1,946 @@
from __future__ import annotations
import warnings
from importlib import util
from typing import (
TYPE_CHECKING,
Any,
Callable,
Literal,
Optional,
Union,
cast,
overload,
)
from langchain_core.language_models import (
BaseChatModel,
LanguageModelInput,
)
from langchain_core.messages import AnyMessage, BaseMessage
from langchain_core.runnables import Runnable, RunnableConfig, ensure_config
from typing_extensions import TypeAlias, override
if TYPE_CHECKING:
from collections.abc import AsyncIterator, Iterator, Sequence
from langchain_core.runnables.schema import StreamEvent
from langchain_core.tools import BaseTool
from langchain_core.tracers import RunLog, RunLogPatch
from pydantic import BaseModel
@overload
def init_chat_model(
model: str,
*,
model_provider: Optional[str] = None,
configurable_fields: Literal[None] = None,
config_prefix: Optional[str] = None,
**kwargs: Any,
) -> BaseChatModel: ...
@overload
def init_chat_model(
model: Literal[None] = None,
*,
model_provider: Optional[str] = None,
configurable_fields: Literal[None] = None,
config_prefix: Optional[str] = None,
**kwargs: Any,
) -> _ConfigurableModel: ...
@overload
def init_chat_model(
model: Optional[str] = None,
*,
model_provider: Optional[str] = None,
configurable_fields: Union[Literal["any"], list[str], tuple[str, ...]] = ...,
config_prefix: Optional[str] = None,
**kwargs: Any,
) -> _ConfigurableModel: ...
# FOR CONTRIBUTORS: If adding support for a new provider, please append the provider
# name to the supported list in the docstring below. Do *not* change the order of the
# existing providers.
def init_chat_model(
model: Optional[str] = None,
*,
model_provider: Optional[str] = None,
configurable_fields: Optional[
Union[Literal["any"], list[str], tuple[str, ...]]
] = None,
config_prefix: Optional[str] = None,
**kwargs: Any,
) -> Union[BaseChatModel, _ConfigurableModel]:
"""Initialize a ChatModel from the model name and provider.
**Note:** Must have the integration package corresponding to the model provider
installed.
Args:
model: The name of the model, e.g. "o3-mini", "claude-3-5-sonnet-latest". You can
also specify model and model provider in a single argument using
'{model_provider}:{model}' format, e.g. "openai:o1".
model_provider: The model provider if not specified as part of model arg (see
above). Supported model_provider values and the corresponding integration
package are:
- 'openai' -> langchain-openai
- 'anthropic' -> langchain-anthropic
- 'azure_openai' -> langchain-openai
- 'azure_ai' -> langchain-azure-ai
- 'google_vertexai' -> langchain-google-vertexai
- 'google_genai' -> langchain-google-genai
- 'bedrock' -> langchain-aws
- 'bedrock_converse' -> langchain-aws
- 'cohere' -> langchain-cohere
- 'fireworks' -> langchain-fireworks
- 'together' -> langchain-together
- 'mistralai' -> langchain-mistralai
- 'huggingface' -> langchain-huggingface
- 'groq' -> langchain-groq
- 'ollama' -> langchain-ollama
- 'google_anthropic_vertex' -> langchain-google-vertexai
- 'deepseek' -> langchain-deepseek
- 'ibm' -> langchain-ibm
- 'nvidia' -> langchain-nvidia-ai-endpoints
- 'xai' -> langchain-xai
- 'perplexity' -> langchain-perplexity
Will attempt to infer model_provider from model if not specified. The
following providers will be inferred based on these model prefixes:
- 'gpt-3...' | 'gpt-4...' | 'o1...' | 'o3...' -> 'openai'
- 'claude...' -> 'anthropic'
- 'amazon....' -> 'bedrock'
- 'gemini...' -> 'google_vertexai'
- 'command...' -> 'cohere'
- 'accounts/fireworks...' -> 'fireworks'
- 'mistral...' -> 'mistralai'
- 'deepseek...' -> 'deepseek'
- 'grok...' -> 'xai'
- 'sonar...' -> 'perplexity'
configurable_fields: Which model parameters are
configurable:
- None: No configurable fields.
- "any": All fields are configurable. *See Security Note below.*
- Union[List[str], Tuple[str, ...]]: Specified fields are configurable.
Fields are assumed to have config_prefix stripped if there is a
config_prefix. If model is specified, then defaults to None. If model is
not specified, then defaults to ``("model", "model_provider")``.
***Security Note***: Setting ``configurable_fields="any"`` means fields like
api_key, base_url, etc. can be altered at runtime, potentially redirecting
model requests to a different service/user. Make sure that if you're
accepting untrusted configurations that you enumerate the
``configurable_fields=(...)`` explicitly.
config_prefix: If config_prefix is a non-empty string then model will be
configurable at runtime via the
``config["configurable"]["{config_prefix}_{param}"]`` keys. If
config_prefix is an empty string then model will be configurable via
``config["configurable"]["{param}"]``.
temperature: Model temperature.
max_tokens: Max output tokens.
timeout: The maximum time (in seconds) to wait for a response from the model
before canceling the request.
max_retries: The maximum number of attempts the system will make to resend a
request if it fails due to issues like network timeouts or rate limits.
base_url: The URL of the API endpoint where requests are sent.
rate_limiter: A ``BaseRateLimiter`` to space out requests to avoid exceeding
rate limits.
kwargs: Additional model-specific keyword args to pass to
``<<selected ChatModel>>.__init__(model=model_name, **kwargs)``.
Returns:
A BaseChatModel corresponding to the model_name and model_provider specified if
configurability is inferred to be False. If configurable, a chat model emulator
that initializes the underlying model at runtime once a config is passed in.
Raises:
ValueError: If model_provider cannot be inferred or isn't supported.
ImportError: If the model provider integration package is not installed.
.. dropdown:: Init non-configurable model
:open:
.. code-block:: python
# pip install langchain langchain-openai langchain-anthropic langchain-google-vertexai
from langchain.chat_models import init_chat_model
o3_mini = init_chat_model("openai:o3-mini", temperature=0)
claude_sonnet = init_chat_model("anthropic:claude-3-5-sonnet-latest", temperature=0)
gemini_2_flash = init_chat_model("google_vertexai:gemini-2.0-flash", temperature=0)
o3_mini.invoke("what's your name")
claude_sonnet.invoke("what's your name")
gemini_2_flash.invoke("what's your name")
.. dropdown:: Partially configurable model with no default
.. code-block:: python
# pip install langchain langchain-openai langchain-anthropic
from langchain.chat_models import init_chat_model
# We don't need to specify configurable=True if a model isn't specified.
configurable_model = init_chat_model(temperature=0)
configurable_model.invoke(
"what's your name",
config={"configurable": {"model": "gpt-4o"}}
)
# GPT-4o response
configurable_model.invoke(
"what's your name",
config={"configurable": {"model": "claude-3-5-sonnet-latest"}}
)
# claude-3.5 sonnet response
.. dropdown:: Fully configurable model with a default
.. code-block:: python
# pip install langchain langchain-openai langchain-anthropic
from langchain.chat_models import init_chat_model
configurable_model_with_default = init_chat_model(
"openai:gpt-4o",
configurable_fields="any", # this allows us to configure other params like temperature, max_tokens, etc at runtime.
config_prefix="foo",
temperature=0
)
configurable_model_with_default.invoke("what's your name")
# GPT-4o response with temperature 0
configurable_model_with_default.invoke(
"what's your name",
config={
"configurable": {
"foo_model": "anthropic:claude-3-5-sonnet-20240620",
"foo_temperature": 0.6
}
}
)
# Claude-3.5 sonnet response with temperature 0.6
.. dropdown:: Bind tools to a configurable model
You can call any ChatModel declarative methods on a configurable model in the
same way that you would with a normal model.
.. code-block:: python
# pip install langchain langchain-openai langchain-anthropic
from langchain.chat_models import init_chat_model
from pydantic import BaseModel, Field
class GetWeather(BaseModel):
'''Get the current weather in a given location'''
location: str = Field(..., description="The city and state, e.g. San Francisco, CA")
class GetPopulation(BaseModel):
'''Get the current population in a given location'''
location: str = Field(..., description="The city and state, e.g. San Francisco, CA")
configurable_model = init_chat_model(
"gpt-4o",
configurable_fields=("model", "model_provider"),
temperature=0
)
configurable_model_with_tools = configurable_model.bind_tools([GetWeather, GetPopulation])
configurable_model_with_tools.invoke(
"Which city is hotter today and which is bigger: LA or NY?"
)
# GPT-4o response with tool calls
configurable_model_with_tools.invoke(
"Which city is hotter today and which is bigger: LA or NY?",
config={"configurable": {"model": "claude-3-5-sonnet-20240620"}}
)
# Claude-3.5 sonnet response with tools
.. versionadded:: 0.2.7
.. versionchanged:: 0.2.8
Support for ``configurable_fields`` and ``config_prefix`` added.
.. versionchanged:: 0.2.12
Support for Ollama via langchain-ollama package added
(langchain_ollama.ChatOllama). Previously,
the now-deprecated langchain-community version of Ollama was imported
(langchain_community.chat_models.ChatOllama).
Support for AWS Bedrock models via the Converse API added
(model_provider="bedrock_converse").
.. versionchanged:: 0.3.5
Out of beta.
.. versionchanged:: 0.3.19
Support for Deepseek, IBM, Nvidia, and xAI models added.
""" # noqa: E501
if not model and not configurable_fields:
configurable_fields = ("model", "model_provider")
config_prefix = config_prefix or ""
if config_prefix and not configurable_fields:
warnings.warn(
f"{config_prefix=} has been set but no fields are configurable. Set "
f"`configurable_fields=(...)` to specify the model params that are "
f"configurable.",
stacklevel=2,
)
if not configurable_fields:
return _init_chat_model_helper(
cast("str", model),
model_provider=model_provider,
**kwargs,
)
if model:
kwargs["model"] = model
if model_provider:
kwargs["model_provider"] = model_provider
return _ConfigurableModel(
default_config=kwargs,
config_prefix=config_prefix,
configurable_fields=configurable_fields,
)
def _init_chat_model_helper(
model: str,
*,
model_provider: Optional[str] = None,
**kwargs: Any,
) -> BaseChatModel:
model, model_provider = _parse_model(model, model_provider)
if model_provider == "openai":
_check_pkg("langchain_openai")
from langchain_openai import ChatOpenAI
return ChatOpenAI(model=model, **kwargs)
if model_provider == "anthropic":
_check_pkg("langchain_anthropic")
from langchain_anthropic import ChatAnthropic
return ChatAnthropic(model=model, **kwargs) # type: ignore[call-arg,unused-ignore]
if model_provider == "azure_openai":
_check_pkg("langchain_openai")
from langchain_openai import AzureChatOpenAI
return AzureChatOpenAI(model=model, **kwargs)
if model_provider == "azure_ai":
_check_pkg("langchain_azure_ai")
from langchain_azure_ai.chat_models import AzureAIChatCompletionsModel
return AzureAIChatCompletionsModel(model=model, **kwargs)
if model_provider == "cohere":
_check_pkg("langchain_cohere")
from langchain_cohere import ChatCohere
return ChatCohere(model=model, **kwargs)
if model_provider == "google_vertexai":
_check_pkg("langchain_google_vertexai")
from langchain_google_vertexai import ChatVertexAI
return ChatVertexAI(model=model, **kwargs)
if model_provider == "google_genai":
_check_pkg("langchain_google_genai")
from langchain_google_genai import ChatGoogleGenerativeAI
return ChatGoogleGenerativeAI(model=model, **kwargs)
if model_provider == "fireworks":
_check_pkg("langchain_fireworks")
from langchain_fireworks import ChatFireworks
return ChatFireworks(model=model, **kwargs)
if model_provider == "ollama":
try:
_check_pkg("langchain_ollama")
from langchain_ollama import ChatOllama
except ImportError:
# For backwards compatibility
try:
_check_pkg("langchain_community")
from langchain_community.chat_models import ChatOllama
except ImportError:
# If both langchain-ollama and langchain-community aren't available,
# raise an error related to langchain-ollama
_check_pkg("langchain_ollama")
return ChatOllama(model=model, **kwargs)
if model_provider == "together":
_check_pkg("langchain_together")
from langchain_together import ChatTogether
return ChatTogether(model=model, **kwargs)
if model_provider == "mistralai":
_check_pkg("langchain_mistralai")
from langchain_mistralai import ChatMistralAI
return ChatMistralAI(model=model, **kwargs) # type: ignore[call-arg,unused-ignore]
if model_provider == "huggingface":
_check_pkg("langchain_huggingface")
from langchain_huggingface import ChatHuggingFace
return ChatHuggingFace(model_id=model, **kwargs)
if model_provider == "groq":
_check_pkg("langchain_groq")
from langchain_groq import ChatGroq
return ChatGroq(model=model, **kwargs)
if model_provider == "bedrock":
_check_pkg("langchain_aws")
from langchain_aws import ChatBedrock
return ChatBedrock(model_id=model, **kwargs)
if model_provider == "bedrock_converse":
_check_pkg("langchain_aws")
from langchain_aws import ChatBedrockConverse
return ChatBedrockConverse(model=model, **kwargs)
if model_provider == "google_anthropic_vertex":
_check_pkg("langchain_google_vertexai")
from langchain_google_vertexai.model_garden import ChatAnthropicVertex
return ChatAnthropicVertex(model=model, **kwargs)
if model_provider == "deepseek":
_check_pkg("langchain_deepseek", pkg_kebab="langchain-deepseek")
from langchain_deepseek import ChatDeepSeek
return ChatDeepSeek(model=model, **kwargs)
if model_provider == "nvidia":
_check_pkg("langchain_nvidia_ai_endpoints")
from langchain_nvidia_ai_endpoints import ChatNVIDIA
return ChatNVIDIA(model=model, **kwargs)
if model_provider == "ibm":
_check_pkg("langchain_ibm")
from langchain_ibm import ChatWatsonx
return ChatWatsonx(model_id=model, **kwargs)
if model_provider == "xai":
_check_pkg("langchain_xai")
from langchain_xai import ChatXAI
return ChatXAI(model=model, **kwargs)
if model_provider == "perplexity":
_check_pkg("langchain_perplexity")
from langchain_perplexity import ChatPerplexity
return ChatPerplexity(model=model, **kwargs)
supported = ", ".join(_SUPPORTED_PROVIDERS)
msg = (
f"Unsupported {model_provider=}.\n\nSupported model providers are: {supported}"
)
raise ValueError(msg)
_SUPPORTED_PROVIDERS = {
"openai",
"anthropic",
"azure_openai",
"azure_ai",
"cohere",
"google_vertexai",
"google_genai",
"fireworks",
"ollama",
"together",
"mistralai",
"huggingface",
"groq",
"bedrock",
"bedrock_converse",
"google_anthropic_vertex",
"deepseek",
"ibm",
"xai",
"perplexity",
}
def _attempt_infer_model_provider(model_name: str) -> Optional[str]:
if any(model_name.startswith(pre) for pre in ("gpt-3", "gpt-4", "o1", "o3")):
return "openai"
if model_name.startswith("claude"):
return "anthropic"
if model_name.startswith("command"):
return "cohere"
if model_name.startswith("accounts/fireworks"):
return "fireworks"
if model_name.startswith("gemini"):
return "google_vertexai"
if model_name.startswith("amazon."):
return "bedrock"
if model_name.startswith("mistral"):
return "mistralai"
if model_name.startswith("deepseek"):
return "deepseek"
if model_name.startswith("grok"):
return "xai"
if model_name.startswith("sonar"):
return "perplexity"
return None
def _parse_model(model: str, model_provider: Optional[str]) -> tuple[str, str]:
if (
not model_provider
and ":" in model
and model.split(":")[0] in _SUPPORTED_PROVIDERS
):
model_provider = model.split(":")[0]
model = ":".join(model.split(":")[1:])
model_provider = model_provider or _attempt_infer_model_provider(model)
if not model_provider:
msg = (
f"Unable to infer model provider for {model=}, please specify "
f"model_provider directly."
)
raise ValueError(msg)
model_provider = model_provider.replace("-", "_").lower()
return model, model_provider
def _check_pkg(pkg: str, *, pkg_kebab: Optional[str] = None) -> None:
if not util.find_spec(pkg):
pkg_kebab = pkg_kebab if pkg_kebab is not None else pkg.replace("_", "-")
msg = (
f"Unable to import {pkg}. Please install with `pip install -U {pkg_kebab}`"
)
raise ImportError(msg)
def _remove_prefix(s: str, prefix: str) -> str:
return s.removeprefix(prefix)
_DECLARATIVE_METHODS = ("bind_tools", "with_structured_output")
class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
def __init__(
self,
*,
default_config: Optional[dict] = None,
configurable_fields: Union[Literal["any"], list[str], tuple[str, ...]] = "any",
config_prefix: str = "",
queued_declarative_operations: Sequence[tuple[str, tuple, dict]] = (),
) -> None:
self._default_config: dict = default_config or {}
self._configurable_fields: Union[Literal["any"], list[str]] = (
configurable_fields
if configurable_fields == "any"
else list(configurable_fields)
)
self._config_prefix = (
config_prefix + "_"
if config_prefix and not config_prefix.endswith("_")
else config_prefix
)
self._queued_declarative_operations: list[tuple[str, tuple, dict]] = list(
queued_declarative_operations,
)
def __getattr__(self, name: str) -> Any:
if name in _DECLARATIVE_METHODS:
# Declarative operations that cannot be applied until after an actual model
# object is instantiated. So instead of returning the actual operation,
# we record the operation and its arguments in a queue. This queue is
# then applied in order whenever we actually instantiate the model (in
# self._model()).
def queue(*args: Any, **kwargs: Any) -> _ConfigurableModel:
queued_declarative_operations = list(
self._queued_declarative_operations,
)
queued_declarative_operations.append((name, args, kwargs))
return _ConfigurableModel(
default_config=dict(self._default_config),
configurable_fields=list(self._configurable_fields)
if isinstance(self._configurable_fields, list)
else self._configurable_fields,
config_prefix=self._config_prefix,
queued_declarative_operations=queued_declarative_operations,
)
return queue
if self._default_config and (model := self._model()) and hasattr(model, name):
return getattr(model, name)
msg = f"{name} is not a BaseChatModel attribute"
if self._default_config:
msg += " and is not implemented on the default model"
msg += "."
raise AttributeError(msg)
def _model(self, config: Optional[RunnableConfig] = None) -> Runnable:
params = {**self._default_config, **self._model_params(config)}
model = _init_chat_model_helper(**params)
for name, args, kwargs in self._queued_declarative_operations:
model = getattr(model, name)(*args, **kwargs)
return model
def _model_params(self, config: Optional[RunnableConfig]) -> dict:
config = ensure_config(config)
model_params = {
_remove_prefix(k, self._config_prefix): v
for k, v in config.get("configurable", {}).items()
if k.startswith(self._config_prefix)
}
if self._configurable_fields != "any":
model_params = {
k: v for k, v in model_params.items() if k in self._configurable_fields
}
return model_params
def with_config(
self,
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> _ConfigurableModel:
"""Bind config to a Runnable, returning a new Runnable."""
config = RunnableConfig(**(config or {}), **cast("RunnableConfig", kwargs))
model_params = self._model_params(config)
remaining_config = {k: v for k, v in config.items() if k != "configurable"}
remaining_config["configurable"] = {
k: v
for k, v in config.get("configurable", {}).items()
if _remove_prefix(k, self._config_prefix) not in model_params
}
queued_declarative_operations = list(self._queued_declarative_operations)
if remaining_config:
queued_declarative_operations.append(
(
"with_config",
(),
{"config": remaining_config},
),
)
return _ConfigurableModel(
default_config={**self._default_config, **model_params},
configurable_fields=list(self._configurable_fields)
if isinstance(self._configurable_fields, list)
else self._configurable_fields,
config_prefix=self._config_prefix,
queued_declarative_operations=queued_declarative_operations,
)
@property
def InputType(self) -> TypeAlias:
"""Get the input type for this runnable."""
from langchain_core.prompt_values import (
ChatPromptValueConcrete,
StringPromptValue,
)
# This is a version of LanguageModelInput which replaces the abstract
# base class BaseMessage with a union of its subclasses, which makes
# for a much better schema.
return Union[
str,
Union[StringPromptValue, ChatPromptValueConcrete],
list[AnyMessage],
]
@override
def invoke(
self,
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Any:
return self._model(config).invoke(input, config=config, **kwargs)
@override
async def ainvoke(
self,
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Any:
return await self._model(config).ainvoke(input, config=config, **kwargs)
@override
def stream(
self,
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Iterator[Any]:
yield from self._model(config).stream(input, config=config, **kwargs)
@override
async def astream(
self,
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> AsyncIterator[Any]:
async for x in self._model(config).astream(input, config=config, **kwargs):
yield x
def batch(
self,
inputs: list[LanguageModelInput],
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> list[Any]:
config = config or None
# If <= 1 config, use the underlying model's batch implementation.
if config is None or isinstance(config, dict) or len(config) <= 1:
if isinstance(config, list):
config = config[0]
return self._model(config).batch(
inputs,
config=config,
return_exceptions=return_exceptions,
**kwargs,
)
# If multiple configs default to Runnable.batch which uses executor to invoke
# in parallel.
return super().batch(
inputs,
config=config,
return_exceptions=return_exceptions,
**kwargs,
)
async def abatch(
self,
inputs: list[LanguageModelInput],
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> list[Any]:
config = config or None
# If <= 1 config, use the underlying model's batch implementation.
if config is None or isinstance(config, dict) or len(config) <= 1:
if isinstance(config, list):
config = config[0]
return await self._model(config).abatch(
inputs,
config=config,
return_exceptions=return_exceptions,
**kwargs,
)
# If multiple configs default to Runnable.batch which uses executor to invoke
# in parallel.
return await super().abatch(
inputs,
config=config,
return_exceptions=return_exceptions,
**kwargs,
)
def batch_as_completed(
self,
inputs: Sequence[LanguageModelInput],
config: Optional[Union[RunnableConfig, Sequence[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Any,
) -> Iterator[tuple[int, Union[Any, Exception]]]:
config = config or None
# If <= 1 config, use the underlying model's batch implementation.
if config is None or isinstance(config, dict) or len(config) <= 1:
if isinstance(config, list):
config = config[0]
yield from self._model(cast("RunnableConfig", config)).batch_as_completed( # type: ignore[call-overload]
inputs,
config=config,
return_exceptions=return_exceptions,
**kwargs,
)
# If multiple configs default to Runnable.batch which uses executor to invoke
# in parallel.
else:
yield from super().batch_as_completed( # type: ignore[call-overload]
inputs,
config=config,
return_exceptions=return_exceptions,
**kwargs,
)
async def abatch_as_completed(
self,
inputs: Sequence[LanguageModelInput],
config: Optional[Union[RunnableConfig, Sequence[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Any,
) -> AsyncIterator[tuple[int, Any]]:
config = config or None
# If <= 1 config, use the underlying model's batch implementation.
if config is None or isinstance(config, dict) or len(config) <= 1:
if isinstance(config, list):
config = config[0]
async for x in self._model(
cast("RunnableConfig", config),
).abatch_as_completed( # type: ignore[call-overload]
inputs,
config=config,
return_exceptions=return_exceptions,
**kwargs,
):
yield x
# If multiple configs default to Runnable.batch which uses executor to invoke
# in parallel.
else:
async for x in super().abatch_as_completed( # type: ignore[call-overload]
inputs,
config=config,
return_exceptions=return_exceptions,
**kwargs,
):
yield x
@override
def transform(
self,
input: Iterator[LanguageModelInput],
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Iterator[Any]:
yield from self._model(config).transform(input, config=config, **kwargs)
@override
async def atransform(
self,
input: AsyncIterator[LanguageModelInput],
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> AsyncIterator[Any]:
async for x in self._model(config).atransform(input, config=config, **kwargs):
yield x
@overload
def astream_log(
self,
input: Any,
config: Optional[RunnableConfig] = None,
*,
diff: Literal[True] = True,
with_streamed_output_list: bool = True,
include_names: Optional[Sequence[str]] = None,
include_types: Optional[Sequence[str]] = None,
include_tags: Optional[Sequence[str]] = None,
exclude_names: Optional[Sequence[str]] = None,
exclude_types: Optional[Sequence[str]] = None,
exclude_tags: Optional[Sequence[str]] = None,
**kwargs: Any,
) -> AsyncIterator[RunLogPatch]: ...
@overload
def astream_log(
self,
input: Any,
config: Optional[RunnableConfig] = None,
*,
diff: Literal[False],
with_streamed_output_list: bool = True,
include_names: Optional[Sequence[str]] = None,
include_types: Optional[Sequence[str]] = None,
include_tags: Optional[Sequence[str]] = None,
exclude_names: Optional[Sequence[str]] = None,
exclude_types: Optional[Sequence[str]] = None,
exclude_tags: Optional[Sequence[str]] = None,
**kwargs: Any,
) -> AsyncIterator[RunLog]: ...
@override
async def astream_log(
self,
input: Any,
config: Optional[RunnableConfig] = None,
*,
diff: bool = True,
with_streamed_output_list: bool = True,
include_names: Optional[Sequence[str]] = None,
include_types: Optional[Sequence[str]] = None,
include_tags: Optional[Sequence[str]] = None,
exclude_names: Optional[Sequence[str]] = None,
exclude_types: Optional[Sequence[str]] = None,
exclude_tags: Optional[Sequence[str]] = None,
**kwargs: Any,
) -> Union[AsyncIterator[RunLogPatch], AsyncIterator[RunLog]]:
async for x in self._model(config).astream_log( # type: ignore[call-overload, misc]
input,
config=config,
diff=diff,
with_streamed_output_list=with_streamed_output_list,
include_names=include_names,
include_types=include_types,
include_tags=include_tags,
exclude_tags=exclude_tags,
exclude_types=exclude_types,
exclude_names=exclude_names,
**kwargs,
):
yield x
@override
async def astream_events(
self,
input: Any,
config: Optional[RunnableConfig] = None,
*,
version: Literal["v1", "v2"] = "v2",
include_names: Optional[Sequence[str]] = None,
include_types: Optional[Sequence[str]] = None,
include_tags: Optional[Sequence[str]] = None,
exclude_names: Optional[Sequence[str]] = None,
exclude_types: Optional[Sequence[str]] = None,
exclude_tags: Optional[Sequence[str]] = None,
**kwargs: Any,
) -> AsyncIterator[StreamEvent]:
async for x in self._model(config).astream_events(
input,
config=config,
version=version,
include_names=include_names,
include_types=include_types,
include_tags=include_tags,
exclude_tags=exclude_tags,
exclude_types=exclude_types,
exclude_names=exclude_names,
**kwargs,
):
yield x
# Explicitly added to satisfy downstream linters.
def bind_tools(
self,
tools: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]],
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
return self.__getattr__("bind_tools")(tools, **kwargs)
# Explicitly added to satisfy downstream linters.
def with_structured_output(
self,
schema: Union[dict, type[BaseModel]],
**kwargs: Any,
) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]:
return self.__getattr__("with_structured_output")(schema, **kwargs)
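
A minimal sketch of the declarative-operation queueing implemented above, assuming langchain-openai and langchain-anthropic are installed: a bind_tools call on a configurable model with no default is recorded and replayed once a concrete model is resolved from the runtime config.

from pydantic import BaseModel, Field

from langchain.chat_models import init_chat_model


class GetWeather(BaseModel):
    """Get the current weather in a given location."""

    location: str = Field(..., description="City and state, e.g. San Francisco, CA")


configurable = init_chat_model()                    # no default model specified
with_tools = configurable.bind_tools([GetWeather])  # queued; no model instantiated yet
with_tools.invoke(
    "What's the weather in SF?",
    config={"configurable": {"model": "openai:gpt-4o"}},
)
with_tools.invoke(
    "What's the weather in SF?",
    config={"configurable": {"model": "anthropic:claude-3-5-sonnet-latest"}},
)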

View File

@@ -0,0 +1,7 @@
from langchain.embeddings.base import init_embeddings
from langchain.embeddings.cache import CacheBackedEmbeddings
__all__ = [
"CacheBackedEmbeddings",
"init_embeddings",
]

View File

@@ -0,0 +1,235 @@
import functools
from importlib import util
from typing import Any, Optional, Union
from langchain_core.embeddings import Embeddings
from langchain_core.runnables import Runnable
_SUPPORTED_PROVIDERS = {
"azure_openai": "langchain_openai",
"bedrock": "langchain_aws",
"cohere": "langchain_cohere",
"google_vertexai": "langchain_google_vertexai",
"huggingface": "langchain_huggingface",
"mistralai": "langchain_mistralai",
"ollama": "langchain_ollama",
"openai": "langchain_openai",
}
def _get_provider_list() -> str:
"""Get formatted list of providers and their packages."""
return "\n".join(
f" - {p}: {pkg.replace('_', '-')}" for p, pkg in _SUPPORTED_PROVIDERS.items()
)
def _parse_model_string(model_name: str) -> tuple[str, str]:
"""Parse a model string into provider and model name components.
The model string should be in the format 'provider:model-name', where provider
is one of the supported providers.
Args:
model_name: A model string in the format 'provider:model-name'
Returns:
A tuple of (provider, model_name)
.. code-block:: python
_parse_model_string("openai:text-embedding-3-small")
# Returns: ("openai", "text-embedding-3-small")
_parse_model_string("bedrock:amazon.titan-embed-text-v1")
# Returns: ("bedrock", "amazon.titan-embed-text-v1")
Raises:
ValueError: If the model string is not in the correct format or
the provider is unsupported
"""
if ":" not in model_name:
providers = _SUPPORTED_PROVIDERS
msg = (
f"Invalid model format '{model_name}'.\n"
f"Model name must be in format 'provider:model-name'\n"
f"Example valid model strings:\n"
f" - openai:text-embedding-3-small\n"
f" - bedrock:amazon.titan-embed-text-v1\n"
f" - cohere:embed-english-v3.0\n"
f"Supported providers: {providers}"
)
raise ValueError(msg)
provider, model = model_name.split(":", 1)
provider = provider.lower().strip()
model = model.strip()
if provider not in _SUPPORTED_PROVIDERS:
msg = (
f"Provider '{provider}' is not supported.\n"
f"Supported providers and their required packages:\n"
f"{_get_provider_list()}"
)
raise ValueError(msg)
if not model:
msg = "Model name cannot be empty"
raise ValueError(msg)
return provider, model
def _infer_model_and_provider(
model: str,
*,
provider: Optional[str] = None,
) -> tuple[str, str]:
if not model.strip():
msg = "Model name cannot be empty"
raise ValueError(msg)
if provider is None and ":" in model:
provider, model_name = _parse_model_string(model)
else:
model_name = model
if not provider:
providers = _SUPPORTED_PROVIDERS
msg = (
"Must specify either:\n"
"1. A model string in format 'provider:model-name'\n"
" Example: 'openai:text-embedding-3-small'\n"
"2. Or explicitly set provider from: "
f"{providers}"
)
raise ValueError(msg)
if provider not in _SUPPORTED_PROVIDERS:
msg = (
f"Provider '{provider}' is not supported.\n"
f"Supported providers and their required packages:\n"
f"{_get_provider_list()}"
)
raise ValueError(msg)
return provider, model_name
@functools.lru_cache(maxsize=len(_SUPPORTED_PROVIDERS))
def _check_pkg(pkg: str) -> None:
"""Check if a package is installed."""
if not util.find_spec(pkg):
msg = (
f"Could not import {pkg} python package. "
f"Please install it with `pip install {pkg}`"
)
raise ImportError(msg)
def init_embeddings(
model: str,
*,
provider: Optional[str] = None,
**kwargs: Any,
) -> Union[Embeddings, Runnable[Any, list[float]]]:
"""Initialize an embeddings model from a model name and optional provider.
**Note:** Must have the integration package corresponding to the model provider
installed.
Args:
model: Name of the model to use. Can be either:
- A model string like "openai:text-embedding-3-small"
- Just the model name if provider is specified
provider: Optional explicit provider name. If not specified,
will attempt to parse from the model string. Supported providers
and their required packages:
{_get_provider_list()}
**kwargs: Additional model-specific parameters passed to the embedding model.
These vary by provider, see the provider-specific documentation for details.
Returns:
An Embeddings instance that can generate embeddings for text.
Raises:
ValueError: If the model provider is not supported or cannot be determined
ImportError: If the required provider package is not installed
.. dropdown:: Example Usage
:open:
.. code-block:: python
# Using a model string
model = init_embeddings("openai:text-embedding-3-small")
model.embed_query("Hello, world!")
# Using explicit provider
model = init_embeddings(
model="text-embedding-3-small",
provider="openai"
)
model.embed_documents(["Hello, world!", "Goodbye, world!"])
# With additional parameters
model = init_embeddings(
"openai:text-embedding-3-small",
api_key="sk-..."
)
.. versionadded:: 0.3.9
"""
if not model:
providers = _SUPPORTED_PROVIDERS.keys()
msg = (
f"Must specify model name. Supported providers are: {', '.join(providers)}"
)
raise ValueError(msg)
provider, model_name = _infer_model_and_provider(model, provider=provider)
pkg = _SUPPORTED_PROVIDERS[provider]
_check_pkg(pkg)
if provider == "openai":
from langchain_openai import OpenAIEmbeddings
return OpenAIEmbeddings(model=model_name, **kwargs)
if provider == "azure_openai":
from langchain_openai import AzureOpenAIEmbeddings
return AzureOpenAIEmbeddings(model=model_name, **kwargs)
if provider == "google_vertexai":
from langchain_google_vertexai import VertexAIEmbeddings
return VertexAIEmbeddings(model=model_name, **kwargs)
if provider == "bedrock":
from langchain_aws import BedrockEmbeddings
return BedrockEmbeddings(model_id=model_name, **kwargs)
if provider == "cohere":
from langchain_cohere import CohereEmbeddings
return CohereEmbeddings(model=model_name, **kwargs)
if provider == "mistralai":
from langchain_mistralai import MistralAIEmbeddings
return MistralAIEmbeddings(model=model_name, **kwargs)
if provider == "huggingface":
from langchain_huggingface import HuggingFaceEmbeddings
return HuggingFaceEmbeddings(model_name=model_name, **kwargs)
if provider == "ollama":
from langchain_ollama import OllamaEmbeddings
return OllamaEmbeddings(model=model_name, **kwargs)
msg = (
f"Provider '{provider}' is not supported.\n"
f"Supported providers and their required packages:\n"
f"{_get_provider_list()}"
)
raise ValueError(msg)
__all__ = [
"Embeddings", # This one is for backwards compatibility
"init_embeddings",
]

View File

@@ -0,0 +1,371 @@
"""Module contains code for a cache backed embedder.
The cache backed embedder is a wrapper around an embedder that caches
embeddings in a key-value store. The cache is used to avoid recomputing
embeddings for the same text.
The text is hashed and the hash is used as the key in the cache.
"""
from __future__ import annotations
import hashlib
import json
import uuid
import warnings
from typing import TYPE_CHECKING, Callable, Literal, Optional, Union, cast
from langchain_core.embeddings import Embeddings
from langchain_core.utils.iter import batch_iterate
from langchain.storage.encoder_backed import EncoderBackedStore
if TYPE_CHECKING:
from collections.abc import Sequence
from langchain_core.stores import BaseStore, ByteStore
NAMESPACE_UUID = uuid.UUID(int=1985)
def _sha1_hash_to_uuid(text: str) -> uuid.UUID:
"""Return a UUID derived from *text* using SHA-1 (deterministic).
Deterministic and fast, **but not collision-resistant**.
A malicious attacker could try to create two different texts that hash to the same
UUID. This may not necessarily be an issue in the context of caching embeddings,
but new applications should swap this out for a collision-resistant hash function
such as BLAKE2 or SHA-256.
"""
sha1_hex = hashlib.sha1(text.encode("utf-8"), usedforsecurity=False).hexdigest()
# Embed the hex string in `uuid5` to obtain a valid UUID.
return uuid.uuid5(NAMESPACE_UUID, sha1_hex)
def _make_default_key_encoder(namespace: str, algorithm: str) -> Callable[[str], str]:
"""Create a default key encoder function.
Args:
namespace: Prefix that segregates keys from different embedding models.
algorithm:
* ``'sha1'`` - fast but not collision-resistant
* ``'blake2b'`` - cryptographically strong, faster than SHA-1
* ``'sha256'`` - cryptographically strong, slower than SHA-1
* ``'sha512'`` - cryptographically strong, slower than SHA-1
Returns:
A function that encodes a key using the specified algorithm.
"""
if algorithm == "sha1":
_warn_about_sha1_encoder()
def _key_encoder(key: str) -> str:
"""Encode a key using the specified algorithm."""
if algorithm == "sha1":
return f"{namespace}{_sha1_hash_to_uuid(key)}"
if algorithm == "blake2b":
return f"{namespace}{hashlib.blake2b(key.encode('utf-8')).hexdigest()}"
if algorithm == "sha256":
return f"{namespace}{hashlib.sha256(key.encode('utf-8')).hexdigest()}"
if algorithm == "sha512":
return f"{namespace}{hashlib.sha512(key.encode('utf-8')).hexdigest()}"
msg = f"Unsupported algorithm: {algorithm}"
raise ValueError(msg)
return _key_encoder
def _value_serializer(value: Sequence[float]) -> bytes:
"""Serialize a value."""
return json.dumps(value).encode()
def _value_deserializer(serialized_value: bytes) -> list[float]:
"""Deserialize a value."""
return cast("list[float]", json.loads(serialized_value.decode()))
# The warning is global; track emission, so it appears only once.
_warned_about_sha1: bool = False
def _warn_about_sha1_encoder() -> None:
"""Emit a one-time warning about SHA-1 collision weaknesses."""
global _warned_about_sha1 # noqa: PLW0603
if not _warned_about_sha1:
warnings.warn(
"Using default key encoder: SHA-1 is *not* collision-resistant. "
"While acceptable for most cache scenarios, a motivated attacker "
"can craft two different payloads that map to the same cache key. "
"If that risk matters in your environment, supply a stronger "
"encoder (e.g. SHA-256 or BLAKE2) via the `key_encoder` argument. "
"If you change the key encoder, consider also creating a new cache, "
"to avoid (the potential for) collisions with existing keys.",
category=UserWarning,
stacklevel=2,
)
_warned_about_sha1 = True
class CacheBackedEmbeddings(Embeddings):
"""Interface for caching results from embedding models.
The interface works with any store that implements the abstract store
interface, accepting keys of type str and values of list of floats.
If need be, the interface can be extended to accept other implementations
of the value serializer and deserializer, as well as the key encoder.
Note that by default only document embeddings are cached. To cache query
embeddings too, pass a query_embedding_store to the constructor.
Examples:
.. code-block:: python
from langchain.embeddings import CacheBackedEmbeddings
from langchain.storage import LocalFileStore
from langchain_community.embeddings import OpenAIEmbeddings
store = LocalFileStore('./my_cache')
underlying_embedder = OpenAIEmbeddings()
embedder = CacheBackedEmbeddings.from_bytes_store(
underlying_embedder, store, namespace=underlying_embedder.model
)
# Embedding is computed and cached
embeddings = embedder.embed_documents(["hello", "goodbye"])
# Embeddings are retrieved from the cache, no computation is done
embeddings = embedder.embed_documents(["hello", "goodbye"])
"""
def __init__(
self,
underlying_embeddings: Embeddings,
document_embedding_store: BaseStore[str, list[float]],
*,
batch_size: Optional[int] = None,
query_embedding_store: Optional[BaseStore[str, list[float]]] = None,
) -> None:
"""Initialize the embedder.
Args:
underlying_embeddings: the embedder to use for computing embeddings.
document_embedding_store: The store to use for caching document embeddings.
batch_size: The number of documents to embed between store updates.
query_embedding_store: The store to use for caching query embeddings.
If ``None``, query embeddings are not cached.
"""
super().__init__()
self.document_embedding_store = document_embedding_store
self.query_embedding_store = query_embedding_store
self.underlying_embeddings = underlying_embeddings
self.batch_size = batch_size
def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed a list of texts.
The method first checks the cache for the embeddings.
If the embeddings are not found, the method uses the underlying embedder
to embed the documents and stores the results in the cache.
Args:
texts: A list of texts to embed.
Returns:
A list of embeddings for the given texts.
"""
vectors: list[Union[list[float], None]] = self.document_embedding_store.mget(
texts,
)
all_missing_indices: list[int] = [
i for i, vector in enumerate(vectors) if vector is None
]
for missing_indices in batch_iterate(self.batch_size, all_missing_indices):
missing_texts = [texts[i] for i in missing_indices]
missing_vectors = self.underlying_embeddings.embed_documents(missing_texts)
self.document_embedding_store.mset(
list(zip(missing_texts, missing_vectors)),
)
for index, updated_vector in zip(missing_indices, missing_vectors):
vectors[index] = updated_vector
return cast(
"list[list[float]]",
vectors,
) # Nones should have been resolved by now
async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed a list of texts.
The method first checks the cache for the embeddings.
If the embeddings are not found, the method uses the underlying embedder
to embed the documents and stores the results in the cache.
Args:
texts: A list of texts to embed.
Returns:
A list of embeddings for the given texts.
"""
vectors: list[
Union[list[float], None]
] = await self.document_embedding_store.amget(texts)
all_missing_indices: list[int] = [
i for i, vector in enumerate(vectors) if vector is None
]
# batch_iterate supports None batch_size which returns all elements at once
# as a single batch.
for missing_indices in batch_iterate(self.batch_size, all_missing_indices):
missing_texts = [texts[i] for i in missing_indices]
missing_vectors = await self.underlying_embeddings.aembed_documents(
missing_texts,
)
await self.document_embedding_store.amset(
list(zip(missing_texts, missing_vectors)),
)
for index, updated_vector in zip(missing_indices, missing_vectors):
vectors[index] = updated_vector
return cast(
"list[list[float]]",
vectors,
) # Nones should have been resolved by now
def embed_query(self, text: str) -> list[float]:
"""Embed query text.
By default, this method does not cache queries. To enable query caching, pass a
``query_embedding_store`` to the constructor (or set ``query_embedding_cache``
in ``from_bytes_store``).
Args:
text: The text to embed.
Returns:
The embedding for the given text.
"""
if not self.query_embedding_store:
return self.underlying_embeddings.embed_query(text)
(cached,) = self.query_embedding_store.mget([text])
if cached is not None:
return cached
vector = self.underlying_embeddings.embed_query(text)
self.query_embedding_store.mset([(text, vector)])
return vector
async def aembed_query(self, text: str) -> list[float]:
"""Embed query text.
By default, this method does not cache queries. To enable query caching, pass a
``query_embedding_store`` to the constructor (or set ``query_embedding_cache``
in ``from_bytes_store``).
Args:
text: The text to embed.
Returns:
The embedding for the given text.
"""
if not self.query_embedding_store:
return await self.underlying_embeddings.aembed_query(text)
(cached,) = await self.query_embedding_store.amget([text])
if cached is not None:
return cached
vector = await self.underlying_embeddings.aembed_query(text)
await self.query_embedding_store.amset([(text, vector)])
return vector
@classmethod
def from_bytes_store(
cls,
underlying_embeddings: Embeddings,
document_embedding_cache: ByteStore,
*,
namespace: str = "",
batch_size: Optional[int] = None,
query_embedding_cache: Union[bool, ByteStore] = False,
key_encoder: Union[
Callable[[str], str],
Literal["sha1", "blake2b", "sha256", "sha512"],
] = "sha1",
) -> CacheBackedEmbeddings:
"""On-ramp that adds the necessary serialization and encoding to the store.
Args:
underlying_embeddings: The embedder to use for embedding.
document_embedding_cache: The cache to use for storing document embeddings.
namespace: The namespace to use for document cache.
This namespace is used to avoid collisions with other caches.
For example, set it to the name of the embedding model used.
batch_size: The number of documents to embed between store updates.
query_embedding_cache: The cache to use for storing query embeddings.
True to use the same cache as document embeddings.
False to not cache query embeddings.
key_encoder: Optional callable to encode keys. If not provided,
a default encoder using SHA-1 will be used. SHA-1 is not
collision-resistant, and a motivated attacker could craft two
different texts that hash to the same cache key.
New applications should use one of the alternative encoders
or provide a strong custom key encoder function to avoid this risk.
If you change the key encoder for an existing cache, consider
creating a new cache instead, to avoid potential collisions with
existing keys or duplicate keys for the same text.
Returns:
An instance of CacheBackedEmbeddings that uses the provided cache.
"""
if isinstance(key_encoder, str):
key_encoder = _make_default_key_encoder(namespace, key_encoder)
elif callable(key_encoder):
# If a custom key encoder is provided, it should not be combined with a
# namespace: the user can handle namespacing directly in their custom
# key encoder.
if namespace:
msg = (
"Do not supply `namespace` when using a custom key_encoder; "
"add any prefixing inside the encoder itself."
)
raise ValueError(msg)
else:
msg = (
"key_encoder must be either 'blake2b', 'sha1', 'sha256', 'sha512' "
"or a callable that encodes keys."
)
raise ValueError(msg) # noqa: TRY004
document_embedding_store = EncoderBackedStore[str, list[float]](
document_embedding_cache,
key_encoder,
_value_serializer,
_value_deserializer,
)
if query_embedding_cache is True:
query_embedding_store = document_embedding_store
elif query_embedding_cache is False:
query_embedding_store = None
else:
query_embedding_store = EncoderBackedStore[str, list[float]](
query_embedding_cache,
key_encoder,
_value_serializer,
_value_deserializer,
)
return cls(
underlying_embeddings,
document_embedding_store,
batch_size=batch_size,
query_embedding_store=query_embedding_store,
)
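# Usage sketch (illustrative): wiring the from_bytes_store factory above up to an
# in-memory byte store. `TinyEmbeddings` is a made-up stand-in for a real
# embedding model; any Embeddings implementation would work the same way.
from langchain_core.embeddings import Embeddings
from langchain_core.stores import InMemoryByteStore


class TinyEmbeddings(Embeddings):
    """Toy embedder that encodes each text by its length (illustration only)."""

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        return [[float(len(text)), 1.0] for text in texts]

    def embed_query(self, text: str) -> list[float]:
        return [float(len(text)), 0.0]


cached_embedder = CacheBackedEmbeddings.from_bytes_store(
    TinyEmbeddings(),
    InMemoryByteStore(),
    namespace="tiny-embeddings",  # e.g. the model name, to avoid key collisions
    key_encoder="sha256",  # prefer a collision-resistant hash over the SHA-1 default
)
# The first call embeds and writes to the cache; the second is served from it.
first = cached_embedder.embed_documents(["hello", "world"])
second = cached_embedder.embed_documents(["hello", "world"])
assert first == second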

View File

@@ -0,0 +1,15 @@
"""Global values and configuration that apply to all of LangChain."""
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from langchain_core.caches import BaseCache
# DO NOT USE THESE VALUES DIRECTLY!
# Use them only via `get_<X>()` and `set_<X>()` below,
# or else your code may behave unexpectedly with other uses of these global settings:
# https://github.com/langchain-ai/langchain/pull/11311#issuecomment-1743780004
_verbose: bool = False
_debug: bool = False
_llm_cache: Optional["BaseCache"] = None

View File

@@ -0,0 +1,16 @@
"""TBD: This module should provide high level building blocks for memory management.
We may want to wait until we combine:
1. langmem
2. some basic functions for message summarization
"""
from langchain_core.messages import filter_messages, trim_messages
from langchain_core.messages.utils import count_tokens_approximately
__all__ = [
"count_tokens_approximately",
"filter_messages",
"trim_messages",
]
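# Usage sketch (illustrative): the helpers re-exported above, applied to a toy
# chat history. The message contents are made up for the example.
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

history = [
    SystemMessage("You are a helpful assistant."),
    HumanMessage("What's the capital of France?"),
    AIMessage("Paris."),
    HumanMessage("And of Germany?"),
    AIMessage("Berlin."),
]

# Keep the most recent messages that fit an approximate token budget,
# always retaining the system message.
trimmed = trim_messages(
    history,
    max_tokens=40,
    token_counter=count_tokens_approximately,
    strategy="last",
    include_system=True,
)

# Keep only the human turns, e.g. as input to a summarization step.
human_turns = filter_messages(history, include_types=["human"])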

View File

@@ -0,0 +1,39 @@
from langchain_core.example_selectors import (
LengthBasedExampleSelector,
MaxMarginalRelevanceExampleSelector,
SemanticSimilarityExampleSelector,
)
from langchain_core.prompts import (
AIMessagePromptTemplate,
BaseChatPromptTemplate,
BasePromptTemplate,
ChatMessagePromptTemplate,
ChatPromptTemplate,
FewShotChatMessagePromptTemplate,
FewShotPromptTemplate,
FewShotPromptWithTemplates,
HumanMessagePromptTemplate,
MessagesPlaceholder,
PromptTemplate,
StringPromptTemplate,
SystemMessagePromptTemplate,
)
__all__ = [
"AIMessagePromptTemplate",
"BaseChatPromptTemplate",
"BasePromptTemplate",
"ChatMessagePromptTemplate",
"ChatPromptTemplate",
"FewShotChatMessagePromptTemplate",
"FewShotPromptTemplate",
"FewShotPromptWithTemplates",
"HumanMessagePromptTemplate",
"LengthBasedExampleSelector",
"MaxMarginalRelevanceExampleSelector",
"MessagesPlaceholder",
"PromptTemplate",
"SemanticSimilarityExampleSelector",
"StringPromptTemplate",
"SystemMessagePromptTemplate",
]
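# Usage sketch (illustrative): composing a chat prompt from the primitives
# re-exported above and rendering it to messages.
from langchain_core.messages import HumanMessage

prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You are a concise assistant."),
        MessagesPlaceholder("history"),
        ("human", "{question}"),
    ],
)
rendered = prompt.invoke(
    {"history": [HumanMessage("Hi there!")], "question": "What is 2 + 2?"},
).to_messages()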

View File

View File

@@ -0,0 +1,27 @@
"""Implementations of key-value stores and storage helpers.
This module provides implementations of various key-value stores that conform
to a simple key-value interface.
The primary goal of these stores is to support caching.
"""
from langchain_core.stores import (
InMemoryByteStore,
InMemoryStore,
InvalidKeyException,
)
from langchain.storage._lc_store import create_kv_docstore, create_lc_store
from langchain.storage.encoder_backed import EncoderBackedStore
from langchain.storage.file_system import LocalFileStore
__all__ = [
"EncoderBackedStore",
"InMemoryByteStore",
"InMemoryStore",
"InvalidKeyException",
"LocalFileStore",
"create_kv_docstore",
"create_lc_store",
]

View File

@@ -0,0 +1,91 @@
"""Create a key-value store for any langchain serializable object."""
from typing import Callable, Optional
from langchain_core.documents import Document
from langchain_core.load import Serializable, dumps, loads
from langchain_core.stores import BaseStore, ByteStore
from langchain.storage.encoder_backed import EncoderBackedStore
def _dump_as_bytes(obj: Serializable) -> bytes:
"""Return a bytes representation of a document."""
return dumps(obj).encode("utf-8")
def _dump_document_as_bytes(obj: Document) -> bytes:
"""Return a bytes representation of a document."""
if not isinstance(obj, Document):
msg = "Expected a Document instance"
raise TypeError(msg)
return dumps(obj).encode("utf-8")
def _load_document_from_bytes(serialized: bytes) -> Document:
"""Return a document from a bytes representation."""
obj = loads(serialized.decode("utf-8"))
if not isinstance(obj, Document):
msg = f"Expected a Document instance. Got {type(obj)}"
raise TypeError(msg)
return obj
def _load_from_bytes(serialized: bytes) -> Serializable:
"""Return a document from a bytes representation."""
return loads(serialized.decode("utf-8"))
def _identity(x: str) -> str:
"""Return the same object."""
return x
# PUBLIC API
def create_lc_store(
store: ByteStore,
*,
key_encoder: Optional[Callable[[str], str]] = None,
) -> BaseStore[str, Serializable]:
"""Create a store for langchain serializable objects from a bytes store.
Args:
store: A bytes store to use as the underlying store.
key_encoder: A function to encode keys; if None uses identity function.
Returns:
A key-value store for documents.
"""
return EncoderBackedStore(
store,
key_encoder or _identity,
_dump_as_bytes,
_load_from_bytes,
)
def create_kv_docstore(
store: ByteStore,
*,
key_encoder: Optional[Callable[[str], str]] = None,
) -> BaseStore[str, Document]:
"""Create a store for langchain Document objects from a bytes store.
This store does run time type checking to ensure that the values are
Document objects.
Args:
store: A bytes store to use as the underlying store.
key_encoder: A function to encode keys; if None uses identity function.
Returns:
A key-value store for documents.
"""
return EncoderBackedStore(
store,
key_encoder or _identity,
_dump_document_as_bytes,
_load_document_from_bytes,
)
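# Usage sketch (illustrative): wrapping an in-memory byte store so it stores
# Document objects with runtime type checking, using create_kv_docstore above.
from langchain_core.stores import InMemoryByteStore

docstore = create_kv_docstore(InMemoryByteStore())
docstore.mset([("doc-1", Document(page_content="hello", metadata={"source": "demo"}))])
(restored,) = docstore.mget(["doc-1"])
assert restored is not None
assert restored.page_content == "hello"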

View File

@@ -0,0 +1,127 @@
from collections.abc import AsyncIterator, Iterator, Sequence
from typing import (
Any,
Callable,
Optional,
TypeVar,
Union,
)
from langchain_core.stores import BaseStore
K = TypeVar("K")
V = TypeVar("V")
class EncoderBackedStore(BaseStore[K, V]):
"""Wraps a store with key and value encoders/decoders.
Example that uses JSON for encoding/decoding:
.. code-block:: python
import json
def key_encoder(key: int) -> str:
return json.dumps(key)
def value_serializer(value: float) -> str:
return json.dumps(value)
def value_deserializer(serialized_value: str) -> float:
return json.loads(serialized_value)
# Create an instance of the abstract store
abstract_store = MyCustomStore()
# Create an instance of the encoder-backed store
store = EncoderBackedStore(
store=abstract_store,
key_encoder=key_encoder,
value_serializer=value_serializer,
value_deserializer=value_deserializer
)
# Use the encoder-backed store methods
store.mset([(1, 3.14), (2, 2.718)])
values = store.mget([1, 2]) # Retrieves [3.14, 2.718]
store.mdelete([1, 2]) # Deletes the keys 1 and 2
"""
def __init__(
self,
store: BaseStore[str, Any],
key_encoder: Callable[[K], str],
value_serializer: Callable[[V], bytes],
value_deserializer: Callable[[Any], V],
) -> None:
"""Initialize an EncodedStore."""
self.store = store
self.key_encoder = key_encoder
self.value_serializer = value_serializer
self.value_deserializer = value_deserializer
def mget(self, keys: Sequence[K]) -> list[Optional[V]]:
"""Get the values associated with the given keys."""
encoded_keys: list[str] = [self.key_encoder(key) for key in keys]
values = self.store.mget(encoded_keys)
return [
self.value_deserializer(value) if value is not None else value
for value in values
]
async def amget(self, keys: Sequence[K]) -> list[Optional[V]]:
"""Get the values associated with the given keys."""
encoded_keys: list[str] = [self.key_encoder(key) for key in keys]
values = await self.store.amget(encoded_keys)
return [
self.value_deserializer(value) if value is not None else value
for value in values
]
def mset(self, key_value_pairs: Sequence[tuple[K, V]]) -> None:
"""Set the values for the given keys."""
encoded_pairs = [
(self.key_encoder(key), self.value_serializer(value))
for key, value in key_value_pairs
]
self.store.mset(encoded_pairs)
async def amset(self, key_value_pairs: Sequence[tuple[K, V]]) -> None:
"""Set the values for the given keys."""
encoded_pairs = [
(self.key_encoder(key), self.value_serializer(value))
for key, value in key_value_pairs
]
await self.store.amset(encoded_pairs)
def mdelete(self, keys: Sequence[K]) -> None:
"""Delete the given keys and their associated values."""
encoded_keys = [self.key_encoder(key) for key in keys]
self.store.mdelete(encoded_keys)
async def amdelete(self, keys: Sequence[K]) -> None:
"""Delete the given keys and their associated values."""
encoded_keys = [self.key_encoder(key) for key in keys]
await self.store.amdelete(encoded_keys)
def yield_keys(
self,
*,
prefix: Optional[str] = None,
) -> Union[Iterator[K], Iterator[str]]:
"""Get an iterator over keys that match the given prefix."""
# For the time being this yields str rather than K; it's kept this way
# for debugging purposes and should be fixed.
yield from self.store.yield_keys(prefix=prefix)
async def ayield_keys(
self,
*,
prefix: Optional[str] = None,
) -> Union[AsyncIterator[K], AsyncIterator[str]]:
"""Get an iterator over keys that match the given prefix."""
# For the time being this yields str rather than K; it's kept this way
# for debugging purposes and should be fixed.
async for key in self.store.ayield_keys(prefix=prefix):
yield key
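# Usage sketch (illustrative): a runnable variant of the docstring example above,
# backed by langchain_core's InMemoryStore and JSON (de)serialization.
import json

from langchain_core.stores import InMemoryStore

int_float_store = EncoderBackedStore(
    store=InMemoryStore(),
    key_encoder=lambda key: json.dumps(key),
    value_serializer=lambda value: json.dumps(value).encode("utf-8"),
    value_deserializer=lambda blob: json.loads(blob.decode("utf-8")),
)
int_float_store.mset([(1, 3.14), (2, 2.718)])
assert int_float_store.mget([1, 2]) == [3.14, 2.718]
int_float_store.mdelete([1, 2])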

View File

@@ -0,0 +1,3 @@
from langchain_core.stores import InvalidKeyException
__all__ = ["InvalidKeyException"]

View File

@@ -0,0 +1,176 @@
import os
import re
import time
from collections.abc import Iterator, Sequence
from pathlib import Path
from typing import Optional, Union
from langchain_core.stores import ByteStore
from langchain.storage.exceptions import InvalidKeyException
class LocalFileStore(ByteStore):
"""BaseStore interface that works on the local file system.
Examples:
Create a LocalFileStore instance and perform operations on it:
.. code-block:: python
from langchain.storage import LocalFileStore
# Instantiate the LocalFileStore with the root path
file_store = LocalFileStore("/path/to/root")
# Set values for keys
file_store.mset([("key1", b"value1"), ("key2", b"value2")])
# Get values for keys
values = file_store.mget(["key1", "key2"]) # Returns [b"value1", b"value2"]
# Delete keys
file_store.mdelete(["key1"])
# Iterate over keys
for key in file_store.yield_keys():
print(key) # noqa: T201
"""
def __init__(
self,
root_path: Union[str, Path],
*,
chmod_file: Optional[int] = None,
chmod_dir: Optional[int] = None,
update_atime: bool = False,
) -> None:
"""Implement the BaseStore interface for the local file system.
Args:
root_path (Union[str, Path]): The root path of the file store. All keys are
interpreted as paths relative to this root.
chmod_file: (optional, defaults to `None`) If specified, sets permissions
for newly created files, overriding the current `umask` if needed.
chmod_dir: (optional, defaults to `None`) If specified, sets permissions
for newly created dirs, overriding the current `umask` if needed.
update_atime: (optional, defaults to `False`) If `True`, updates the
filesystem access time (but not the modified time) when a file is read.
This allows MRU/LRU cache policies to be implemented for filesystems
where access time updates are disabled.
"""
self.root_path = Path(root_path).absolute()
self.chmod_file = chmod_file
self.chmod_dir = chmod_dir
self.update_atime = update_atime
def _get_full_path(self, key: str) -> Path:
"""Get the full path for a given key relative to the root path.
Args:
key (str): The key relative to the root path.
Returns:
Path: The full path for the given key.
"""
if not re.match(r"^[a-zA-Z0-9_.\-/]+$", key):
msg = f"Invalid characters in key: {key}"
raise InvalidKeyException(msg)
full_path = (self.root_path / key).resolve()
root_path = self.root_path.resolve()
common_path = os.path.commonpath([root_path, full_path])
if common_path != str(root_path):
msg = (
f"Invalid key: {key}. Key should be relative to the full path. "
f"{root_path} vs. {common_path} and full path of {full_path}"
)
raise InvalidKeyException(msg)
return full_path
def _mkdir_for_store(self, dir_path: Path) -> None:
"""Makes a store directory path (including parents) with specified permissions.
This is needed because `Path.mkdir()` is restricted by the current `umask`,
whereas the explicit `os.chmod()` used here is not.
Args:
dir_path: (Path) The store directory to make
Returns:
None
"""
if not dir_path.exists():
self._mkdir_for_store(dir_path.parent)
dir_path.mkdir(exist_ok=True)
if self.chmod_dir is not None:
dir_path.chmod(self.chmod_dir)
def mget(self, keys: Sequence[str]) -> list[Optional[bytes]]:
"""Get the values associated with the given keys.
Args:
keys: A sequence of keys.
Returns:
A sequence of optional values associated with the keys.
If a key is not found, the corresponding value will be None.
"""
values: list[Optional[bytes]] = []
for key in keys:
full_path = self._get_full_path(key)
if full_path.exists():
value = full_path.read_bytes()
values.append(value)
if self.update_atime:
# update access time only; preserve modified time
os.utime(full_path, (time.time(), full_path.stat().st_mtime))
else:
values.append(None)
return values
def mset(self, key_value_pairs: Sequence[tuple[str, bytes]]) -> None:
"""Set the values for the given keys.
Args:
key_value_pairs: A sequence of key-value pairs.
Returns:
None
"""
for key, value in key_value_pairs:
full_path = self._get_full_path(key)
self._mkdir_for_store(full_path.parent)
full_path.write_bytes(value)
if self.chmod_file is not None:
full_path.chmod(self.chmod_file)
def mdelete(self, keys: Sequence[str]) -> None:
"""Delete the given keys and their associated values.
Args:
keys (Sequence[str]): A sequence of keys to delete.
Returns:
None
"""
for key in keys:
full_path = self._get_full_path(key)
if full_path.exists():
full_path.unlink()
def yield_keys(self, prefix: Optional[str] = None) -> Iterator[str]:
"""Get an iterator over keys that match the given prefix.
Args:
prefix (Optional[str]): The prefix to match.
Returns:
Iterator[str]: An iterator over keys that match the given prefix.
"""
prefix_path = self._get_full_path(prefix) if prefix else self.root_path
for file in prefix_path.rglob("*"):
if file.is_file():
relative_path = file.relative_to(self.root_path)
yield str(relative_path)

View File

@@ -0,0 +1,13 @@
"""In memory store that is not thread safe and has no eviction policy.
This is a simple implementation of the BaseStore using a dictionary that is useful
primarily for unit testing purposes.
"""
from langchain_core.stores import InMemoryBaseStore, InMemoryByteStore, InMemoryStore
__all__ = [
"InMemoryBaseStore",
"InMemoryByteStore",
"InMemoryStore",
]

View File

@@ -0,0 +1,50 @@
"""Kept for backwards compatibility."""
from langchain_text_splitters import (
Language,
RecursiveCharacterTextSplitter,
TextSplitter,
Tokenizer,
TokenTextSplitter,
)
from langchain_text_splitters.base import split_text_on_tokens
from langchain_text_splitters.character import CharacterTextSplitter
from langchain_text_splitters.html import ElementType, HTMLHeaderTextSplitter
from langchain_text_splitters.json import RecursiveJsonSplitter
from langchain_text_splitters.konlpy import KonlpyTextSplitter
from langchain_text_splitters.latex import LatexTextSplitter
from langchain_text_splitters.markdown import (
HeaderType,
LineType,
MarkdownHeaderTextSplitter,
MarkdownTextSplitter,
)
from langchain_text_splitters.nltk import NLTKTextSplitter
from langchain_text_splitters.python import PythonCodeTextSplitter
from langchain_text_splitters.sentence_transformers import (
SentenceTransformersTokenTextSplitter,
)
from langchain_text_splitters.spacy import SpacyTextSplitter
__all__ = [
"CharacterTextSplitter",
"ElementType",
"HTMLHeaderTextSplitter",
"HeaderType",
"KonlpyTextSplitter",
"Language",
"LatexTextSplitter",
"LineType",
"MarkdownHeaderTextSplitter",
"MarkdownTextSplitter",
"NLTKTextSplitter",
"PythonCodeTextSplitter",
"RecursiveCharacterTextSplitter",
"RecursiveJsonSplitter",
"SentenceTransformersTokenTextSplitter",
"SpacyTextSplitter",
"TextSplitter",
"TokenTextSplitter",
"Tokenizer",
"split_text_on_tokens",
]
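# Usage sketch (illustrative): splitting a long string into overlapping chunks
# with the re-exported RecursiveCharacterTextSplitter.
splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=20)
chunks = splitter.split_text(
    "LangChain text splitters break long documents into overlapping chunks "
    "so they can be embedded or passed to a context-window-limited model.",
)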

View File

@@ -0,0 +1,17 @@
from langchain_core.tools import (
BaseTool,
InjectedToolArg,
InjectedToolCallId,
Tool,
ToolException,
tool,
)
__all__ = [
"BaseTool",
"InjectedToolArg",
"InjectedToolCallId",
"Tool",
"ToolException",
"tool",
]
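# Usage sketch (illustrative): defining and invoking a tool with the
# re-exported @tool decorator.
@tool
def multiply(a: int, b: int) -> int:
    """Multiply two integers."""
    return a * b


result = multiply.invoke({"a": 6, "b": 7})  # -> 42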

View File

@@ -0,0 +1,190 @@
[build-system]
requires = ["pdm-backend"]
build-backend = "pdm.backend"
[project]
authors = []
license = { text = "MIT" }
requires-python = ">=3.9, <4.0"
dependencies = [
"langchain-core<1.0.0,>=0.3.66",
"langchain-text-splitters<1.0.0,>=0.3.8",
"langgraph>=0.5.4",
"pydantic>=2.7.4",
]
name = "langchain"
version = "1.0.0dev1"
description = "Building applications with LLMs through composability"
readme = "README.md"
[project.optional-dependencies]
# community = ["langchain-community"]
anthropic = ["langchain-anthropic"]
openai = ["langchain-openai"]
azure-ai = ["langchain-azure-ai"]
# cohere = ["langchain-cohere"]
google-vertexai = ["langchain-google-vertexai"]
google-genai = ["langchain-google-genai"]
fireworks = ["langchain-fireworks"]
ollama = ["langchain-ollama"]
together = ["langchain-together"]
mistralai = ["langchain-mistralai"]
huggingface = ["langchain-huggingface"]
groq = ["langchain-groq"]
aws = ["langchain-aws"]
deepseek = ["langchain-deepseek"]
xai = ["langchain-xai"]
perplexity = ["langchain-perplexity"]
[project.urls]
"Source Code" = "https://github.com/langchain-ai/langchain/tree/master/libs/langchain"
"Release Notes" = "https://github.com/langchain-ai/langchain/releases?q=tag%3A%22langchain%3D%3D0%22&expanded=true"
repository = "https://github.com/langchain-ai/langchain"
[dependency-groups]
test = [
"pytest<9,>=8",
"pytest-cov<5.0.0,>=4.0.0",
"pytest-watcher<1.0.0,>=0.2.6",
"pytest-asyncio<1.0.0,>=0.23.2",
"pytest-socket<1.0.0,>=0.6.0",
"syrupy<5.0.0,>=4.0.2",
"pytest-xdist<4.0.0,>=3.6.1",
"blockbuster<1.6,>=1.5.18",
"langchain-tests",
"langchain-core",
"langchain-text-splitters",
"langchain-openai",
"toml>=0.10.2",
]
codespell = ["codespell<3.0.0,>=2.2.0"]
lint = [
"ruff<0.13,>=0.12.2",
"mypy<1.16,>=1.15",
]
typing = [
"types-toml>=0.10.8.20240310",
]
test_integration = [
"vcrpy>=7.0",
"urllib3<2; python_version < \"3.10\"",
"wrapt<2.0.0,>=1.15.0",
"python-dotenv<2.0.0,>=1.0.0",
"cassio<1.0.0,>=0.1.0",
"langchainhub<1.0.0,>=0.1.16",
"langchain-core",
"langchain-text-splitters",
]
[tool.uv.sources]
langchain-core = { path = "../core", editable = true }
langchain-tests = { path = "../standard-tests", editable = true }
langchain-text-splitters = { path = "../text-splitters", editable = true }
langchain-openai = { path = "../partners/openai", editable = true }
[tool.ruff]
target-version = "py39"
exclude = ["tests/integration_tests/examples/non-utf8-encoding.py"]
[tool.mypy]
strict = "True"
strict_bytes = "True"
ignore_missing_imports = "True"
enable_error_code = "deprecated"
report_deprecated_as_note = "True"
# TODO: activate for 'strict' checking
disallow_untyped_calls = "False"
disallow_any_generics = "False"
disallow_untyped_decorators = "False"
warn_return_any = "False"
strict_equality = "False"
[tool.codespell]
skip = ".git,*.pdf,*.svg,*.pdf,*.yaml,*.ipynb,poetry.lock,*.min.js,*.css,package-lock.json,example_data,_dist,examples,*.trig"
ignore-regex = ".*(Stati Uniti|Tense=Pres).*"
ignore-words-list = "momento,collison,ned,foor,reworkd,parth,whats,aapply,mysogyny,unsecure,damon,crate,aadd,symbl,precesses,accademia,nin"
[tool.ruff.lint]
select = [
"ALL"
]
ignore = [
"D100", # pydocstyle: Missing docstring in public module
"D104", # pydocstyle: Missing docstring in public package
"D105", # pydocstyle: Missing docstring in magic method
"COM812", # Messes with the formatter
"ISC001", # Messes with the formatter
"PERF203", # Rarely useful
"SLF001", # Private member access
"UP007", # pyupgrade: non-pep604-annotation-union
"PLC0415", # Imports should be at the top. Not always desirable
"PLR0913", # Too many arguments in function definition
]
unfixable = ["B028"] # People should intentionally tune the stacklevel
pydocstyle.convention = "google"
pyupgrade.keep-runtime-typing = true
flake8-annotations.allow-star-arg-any = true
[tool.ruff.lint.per-file-ignores]
"tests/*" = [
"D", # Documentation rules
"PLC0415", # Imports should be at the top. Not always desirable for tests
]
[tool.ruff.lint.extend-per-file-ignores]
"scripts/check_imports.py" = ["ALL"]
"langchain/globals.py" = [
"PLW"
]
"langchain/chat_models/base.py" = [
"ANN",
"C901",
"FIX002",
"N802",
"PLR0911",
"PLR0912",
"PLR0915",
]
"langchain/embeddings/base.py" = [
"PLR0911",
"PLR0913",
]
"tests/**/*.py" = [
"S101", # Tests need assertions
"S311", # Standard pseudo-random generators are not suitable for cryptographic purposes
"SLF001", # Private member access in tests
"PLR2004", # Magic values are perfectly fine in unit tests (e.g. 0, 1, 2, etc.)
"C901", # Too complex
"ANN401", # Annotated type is not necessary
"N802", # Function name should be lowercase
"PLW1641", # Object does not implement __hash__ method
"ARG002", # Unused argument
"BLE001", # Do not catch blind exception
"N801", # class name should use CapWords convention
]
[tool.coverage.run]
omit = ["tests/*"]
[tool.pytest.ini_options]
addopts = "--strict-markers --strict-config --durations=5 --snapshot-warn-unused -vv"
markers = [
"requires: mark tests as requiring a specific library",
"scheduled: mark tests to run in scheduled testing",
"compile: mark placeholder test used to compile integration tests without running them",
]
asyncio_mode = "auto"
filterwarnings = [
"ignore::langchain_core._api.beta_decorator.LangChainBetaWarning",
"ignore::langchain_core._api.deprecation.LangChainDeprecationWarning:tests",
"ignore::langchain_core._api.deprecation.LangChainPendingDeprecationWarning:tests",
]

View File

@@ -0,0 +1,31 @@
"""Quickly verify that a list of Python files can be loaded by the Python interpreter
without raising any errors. Ran before running more expensive tests. Useful in
Makefiles.
If loading a file fails, the script prints the problematic filename and the detailed
error traceback.
"""
import random
import string
import sys
import traceback
from importlib.machinery import SourceFileLoader
if __name__ == "__main__":
files = sys.argv[1:]
has_failure = False
for file in files:
try:
module_name = "".join(
random.choice(string.ascii_letters) # noqa: S311
for _ in range(20)
)
SourceFileLoader(module_name, file).load_module()
except Exception:
has_failure = True
print(file) # noqa: T201
traceback.print_exc()
print() # noqa: T201
sys.exit(1 if has_failure else 0)

View File

@@ -0,0 +1 @@
"""All tests for this package."""

View File

@@ -0,0 +1 @@
"""All integration tests (tests that call out to an external API)."""

View File

@@ -0,0 +1 @@
"""All integration tests for Cache objects."""

View File

@@ -0,0 +1,81 @@
"""Fake Embedding class for testing purposes."""
import math
from langchain_core.embeddings import Embeddings
fake_texts = ["foo", "bar", "baz"]
class FakeEmbeddings(Embeddings):
"""Fake embeddings functionality for testing."""
def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Return simple embeddings.
Embeddings encode each text as its index."""
return [[1.0] * 9 + [float(i)] for i in range(len(texts))]
async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
return self.embed_documents(texts)
def embed_query(self, text: str) -> list[float]:
"""Return constant query embeddings.
Embeddings are identical to embed_documents(texts)[0].
Distance to each text will be that text's index,
as it was passed to embed_documents."""
return [1.0] * 9 + [0.0]
async def aembed_query(self, text: str) -> list[float]:
return self.embed_query(text)
class ConsistentFakeEmbeddings(FakeEmbeddings):
"""Fake embeddings which remember all the texts seen so far to return consistent
vectors for the same texts."""
def __init__(self, dimensionality: int = 10) -> None:
self.known_texts: list[str] = []
self.dimensionality = dimensionality
def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Return consistent embeddings for each text seen so far."""
out_vectors = []
for text in texts:
if text not in self.known_texts:
self.known_texts.append(text)
vector = [1.0] * (self.dimensionality - 1) + [
float(self.known_texts.index(text)),
]
out_vectors.append(vector)
return out_vectors
def embed_query(self, text: str) -> list[float]:
"""Return consistent embeddings for the text, if seen before, or a constant
one if the text is unknown."""
return self.embed_documents([text])[0]
class AngularTwoDimensionalEmbeddings(Embeddings):
"""
From angles (as strings in units of pi) to unit embedding vectors on a circle.
"""
def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""
Make a list of texts into a list of embedding vectors.
"""
return [self.embed_query(text) for text in texts]
def embed_query(self, text: str) -> list[float]:
"""
Convert input text to a 'vector' (list of floats).
If the text is a number, use it as the angle for the
unit vector in units of pi.
Any other input text becomes the singular result [0.0, 0.0]!
"""
try:
angle = float(text)
return [math.cos(angle * math.pi), math.sin(angle * math.pi)]
except ValueError:
# Assume it's just a test string; no attention is paid to values.
return [0.0, 0.0]

View File

@@ -0,0 +1,59 @@
from typing import cast
import pytest
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableConfig
from langchain_tests.integration_tests import ChatModelIntegrationTests
from pydantic import BaseModel
from langchain.chat_models import init_chat_model
class multiply(BaseModel):
"""Product of two ints."""
x: int
y: int
@pytest.mark.requires("langchain_openai", "langchain_anthropic")
async def test_init_chat_model_chain() -> None:
model = init_chat_model("gpt-4o", configurable_fields="any", config_prefix="bar")
model_with_tools = model.bind_tools([multiply])
model_with_config = model_with_tools.with_config(
RunnableConfig(tags=["foo"]),
configurable={"bar_model": "claude-3-sonnet-20240229"},
)
prompt = ChatPromptTemplate.from_messages([("system", "foo"), ("human", "{input}")])
chain = prompt | model_with_config
output = chain.invoke({"input": "bar"})
assert isinstance(output, AIMessage)
events = [
event async for event in chain.astream_events({"input": "bar"}, version="v2")
]
assert events
class TestStandard(ChatModelIntegrationTests):
@property
def chat_model_class(self) -> type[BaseChatModel]:
return cast("type[BaseChatModel]", init_chat_model)
@property
def chat_model_params(self) -> dict:
return {"model": "gpt-4o", "configurable_fields": "any"}
@property
def supports_image_inputs(self) -> bool:
return True
@property
def has_tool_calling(self) -> bool:
return True
@property
def has_structured_output(self) -> bool:
return True

View File

@@ -0,0 +1,34 @@
from pathlib import Path
import pytest
# Getting the absolute path of the current file's directory
ABS_PATH = Path(__file__).resolve().parent
# Getting the absolute path of the project's root directory
PROJECT_DIR = ABS_PATH.parent.parent
# Loading the .env file if it exists
def _load_env() -> None:
dotenv_path = PROJECT_DIR / "tests" / "integration_tests" / ".env"
if dotenv_path.exists():
from dotenv import load_dotenv
load_dotenv(dotenv_path)
_load_env()
@pytest.fixture(scope="module")
def test_dir() -> Path:
return PROJECT_DIR / "tests" / "integration_tests"
# This fixture returns a string containing the path to the cassette directory for the
# current module
@pytest.fixture(scope="module")
def vcr_cassette_dir(request: pytest.FixtureRequest) -> str:
module = Path(request.module.__file__)
return str(module.parent / "cassettes" / module.stem)

View File

@@ -0,0 +1,44 @@
"""Test embeddings base module."""
import importlib
import pytest
from langchain_core.embeddings import Embeddings
from langchain.embeddings.base import _SUPPORTED_PROVIDERS, init_embeddings
@pytest.mark.parametrize(
("provider", "model"),
[
("openai", "text-embedding-3-large"),
("google_vertexai", "text-embedding-gecko@003"),
("bedrock", "amazon.titan-embed-text-v1"),
("cohere", "embed-english-v2.0"),
],
)
async def test_init_embedding_model(provider: str, model: str) -> None:
package = _SUPPORTED_PROVIDERS[provider]
try:
importlib.import_module(package)
except ImportError:
pytest.skip(f"Package {package} is not installed")
model_colon = init_embeddings(f"{provider}:{model}")
assert isinstance(model_colon, Embeddings)
model_explicit = init_embeddings(
model=model,
provider=provider,
)
assert isinstance(model_explicit, Embeddings)
text = "Hello world"
embedding_colon = await model_colon.aembed_query(text)
assert isinstance(embedding_colon, list)
assert all(isinstance(x, float) for x in embedding_colon)
embedding_explicit = await model_explicit.aembed_query(text)
assert isinstance(embedding_explicit, list)
assert all(isinstance(x, float) for x in embedding_explicit)

View File

@@ -0,0 +1,6 @@
import pytest
@pytest.mark.compile
def test_placeholder() -> None:
"""Used for compiling integration tests without running any real tests."""

View File

@@ -0,0 +1,236 @@
import os
from typing import TYPE_CHECKING, Optional
from unittest import mock
import pytest
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableConfig, RunnableSequence
from pydantic import SecretStr
from langchain.chat_models import __all__, init_chat_model
if TYPE_CHECKING:
from langchain_core.language_models import BaseChatModel
EXPECTED_ALL = [
"init_chat_model",
]
def test_all_imports() -> None:
"""Test that all expected imports are present in the module's __all__."""
assert set(__all__) == set(EXPECTED_ALL)
@pytest.mark.requires(
"langchain_openai",
"langchain_anthropic",
"langchain_fireworks",
"langchain_groq",
)
@pytest.mark.parametrize(
("model_name", "model_provider"),
[
("gpt-4o", "openai"),
("claude-3-opus-20240229", "anthropic"),
("accounts/fireworks/models/mixtral-8x7b-instruct", "fireworks"),
("mixtral-8x7b-32768", "groq"),
],
)
def test_init_chat_model(model_name: str, model_provider: Optional[str]) -> None:
llm1: BaseChatModel = init_chat_model(
model_name,
model_provider=model_provider,
api_key="foo",
)
llm2: BaseChatModel = init_chat_model(
f"{model_provider}:{model_name}",
api_key="foo",
)
assert llm1.dict() == llm2.dict()
def test_init_missing_dep() -> None:
with pytest.raises(ImportError):
init_chat_model("mixtral-8x7b-32768", model_provider="groq")
def test_init_unknown_provider() -> None:
with pytest.raises(ValueError, match="Unsupported model_provider='bar'."):
init_chat_model("foo", model_provider="bar")
@pytest.mark.requires("langchain_openai")
@mock.patch.dict(
os.environ,
{"OPENAI_API_KEY": "foo", "ANTHROPIC_API_KEY": "bar"},
clear=True,
)
def test_configurable() -> None:
model = init_chat_model()
for method in (
"invoke",
"ainvoke",
"batch",
"abatch",
"stream",
"astream",
"batch_as_completed",
"abatch_as_completed",
):
assert hasattr(model, method)
# Doesn't have access to non-configurable, non-declarative methods until a
# config is provided.
for method in ("get_num_tokens", "get_num_tokens_from_messages"):
with pytest.raises(AttributeError):
getattr(model, method)
# Can call declarative methods even without a default model.
model_with_tools = model.bind_tools(
[{"name": "foo", "description": "foo", "parameters": {}}],
)
# Check that original model wasn't mutated by declarative operation.
assert model._queued_declarative_operations == []
# Can iteratively call declarative methods.
model_with_config = model_with_tools.with_config(
RunnableConfig(tags=["foo"]),
configurable={"model": "gpt-4o"},
)
assert model_with_config.model_name == "gpt-4o" # type: ignore[attr-defined]
for method in ("get_num_tokens", "get_num_tokens_from_messages"):
assert hasattr(model_with_config, method)
assert model_with_config.model_dump() == { # type: ignore[attr-defined]
"name": None,
"bound": {
"name": None,
"disable_streaming": False,
"disabled_params": None,
"model_name": "gpt-4o",
"temperature": None,
"model_kwargs": {},
"openai_api_key": SecretStr("foo"),
"openai_api_base": None,
"openai_organization": None,
"openai_proxy": None,
"output_version": "v0",
"request_timeout": None,
"max_retries": None,
"presence_penalty": None,
"reasoning": None,
"reasoning_effort": None,
"frequency_penalty": None,
"include": None,
"seed": None,
"service_tier": None,
"logprobs": None,
"top_logprobs": None,
"logit_bias": None,
"streaming": False,
"n": None,
"top_p": None,
"truncation": None,
"max_tokens": None,
"tiktoken_model_name": None,
"default_headers": None,
"default_query": None,
"stop": None,
"store": None,
"extra_body": None,
"include_response_headers": False,
"stream_usage": False,
"use_previous_response_id": False,
"use_responses_api": None,
},
"kwargs": {
"tools": [
{
"type": "function",
"function": {"name": "foo", "description": "foo", "parameters": {}},
},
],
},
"config": {"tags": ["foo"], "configurable": {}},
"config_factories": [],
"custom_input_type": None,
"custom_output_type": None,
}
@pytest.mark.requires("langchain_openai", "langchain_anthropic")
@mock.patch.dict(
os.environ,
{"OPENAI_API_KEY": "foo", "ANTHROPIC_API_KEY": "bar"},
clear=True,
)
def test_configurable_with_default() -> None:
model = init_chat_model("gpt-4o", configurable_fields="any", config_prefix="bar")
for method in (
"invoke",
"ainvoke",
"batch",
"abatch",
"stream",
"astream",
"batch_as_completed",
"abatch_as_completed",
):
assert hasattr(model, method)
# Does have access to non-configurable, non-declarative methods since default
# params are provided.
for method in ("get_num_tokens", "get_num_tokens_from_messages", "dict"):
assert hasattr(model, method)
assert model.model_name == "gpt-4o"
model_with_tools = model.bind_tools(
[{"name": "foo", "description": "foo", "parameters": {}}],
)
model_with_config = model_with_tools.with_config(
RunnableConfig(tags=["foo"]),
configurable={"bar_model": "claude-3-sonnet-20240229"},
)
assert model_with_config.model == "claude-3-sonnet-20240229" # type: ignore[attr-defined]
assert model_with_config.model_dump() == { # type: ignore[attr-defined]
"name": None,
"bound": {
"name": None,
"disable_streaming": False,
"model": "claude-3-sonnet-20240229",
"mcp_servers": None,
"max_tokens": 1024,
"temperature": None,
"thinking": None,
"top_k": None,
"top_p": None,
"default_request_timeout": None,
"max_retries": 2,
"stop_sequences": None,
"anthropic_api_url": "https://api.anthropic.com",
"anthropic_api_key": SecretStr("bar"),
"betas": None,
"default_headers": None,
"model_kwargs": {},
"streaming": False,
"stream_usage": True,
},
"kwargs": {
"tools": [{"name": "foo", "description": "foo", "input_schema": {}}],
},
"config": {"tags": ["foo"], "configurable": {}},
"config_factories": [],
"custom_input_type": None,
"custom_output_type": None,
}
prompt = ChatPromptTemplate.from_messages([("system", "foo")])
chain = prompt | model_with_config
assert isinstance(chain, RunnableSequence)

View File

@@ -0,0 +1,122 @@
"""Configuration for unit tests."""
from collections.abc import Iterator, Sequence
from importlib import util
import pytest
from blockbuster import blockbuster_ctx
@pytest.fixture(autouse=True)
def blockbuster() -> Iterator[None]:
with blockbuster_ctx("langchain") as bb:
bb.functions["io.TextIOWrapper.read"].can_block_in(
"langchain/__init__.py",
"<module>",
)
for func in ["os.stat", "os.path.abspath"]:
(
bb.functions[func]
.can_block_in("langchain_core/runnables/base.py", "__repr__")
.can_block_in(
"langchain_core/beta/runnables/context.py",
"aconfig_with_context",
)
)
for func in ["os.stat", "io.TextIOWrapper.read"]:
bb.functions[func].can_block_in(
"langsmith/client.py",
"_default_retry_config",
)
for bb_function in bb.functions.values():
bb_function.can_block_in(
"freezegun/api.py",
"_get_cached_module_attributes",
)
yield
def pytest_addoption(parser: pytest.Parser) -> None:
"""Add custom command line options to pytest."""
parser.addoption(
"--only-extended",
action="store_true",
help="Only run extended tests. Does not allow skipping any extended tests.",
)
parser.addoption(
"--only-core",
action="store_true",
help="Only run core tests. Never runs any extended tests.",
)
def pytest_collection_modifyitems(
config: pytest.Config, items: Sequence[pytest.Function]
) -> None:
"""Add implementations for handling custom markers.
At the moment, this adds support for a custom `requires` marker.
The `requires` marker is used to denote tests that require one or more packages
to be installed to run. If the package is not installed, the test is skipped.
The `requires` marker syntax is:
.. code-block:: python
@pytest.mark.requires("package1", "package2")
def test_something():
...
"""
# Mapping from the name of a package to whether it is installed or not.
# Used to avoid repeated calls to `util.find_spec`
required_pkgs_info: dict[str, bool] = {}
only_extended = config.getoption("--only-extended") or False
only_core = config.getoption("--only-core") or False
if only_extended and only_core:
msg = "Cannot specify both `--only-extended` and `--only-core`."
raise ValueError(msg)
for item in items:
requires_marker = item.get_closest_marker("requires")
if requires_marker is not None:
if only_core:
item.add_marker(pytest.mark.skip(reason="Skipping not a core test."))
continue
# Iterate through the list of required packages
required_pkgs = requires_marker.args
for pkg in required_pkgs:
# If we haven't yet checked whether the pkg is installed
# let's check it and store the result.
if pkg not in required_pkgs_info:
try:
installed = util.find_spec(pkg) is not None
except Exception:
installed = False
required_pkgs_info[pkg] = installed
if not required_pkgs_info[pkg]:
if only_extended:
pytest.fail(
f"Package `{pkg}` is not installed but is required for "
f"extended tests. Please install the given package and "
f"try again.",
)
else:
# If the package is not installed, we immediately break
# and mark the test as skipped.
item.add_marker(
pytest.mark.skip(reason=f"Requires pkg: `{pkg}`"),
)
break
elif only_extended:
item.add_marker(
pytest.mark.skip(reason="Skipping not an extended test."),
)

View File

@@ -0,0 +1,111 @@
"""Test embeddings base module."""
import pytest
from langchain.embeddings.base import (
_SUPPORTED_PROVIDERS,
_infer_model_and_provider,
_parse_model_string,
)
def test_parse_model_string() -> None:
"""Test parsing model strings into provider and model components."""
assert _parse_model_string("openai:text-embedding-3-small") == (
"openai",
"text-embedding-3-small",
)
assert _parse_model_string("bedrock:amazon.titan-embed-text-v1") == (
"bedrock",
"amazon.titan-embed-text-v1",
)
assert _parse_model_string("huggingface:BAAI/bge-base-en:v1.5") == (
"huggingface",
"BAAI/bge-base-en:v1.5",
)
def test_parse_model_string_errors() -> None:
"""Test error cases for model string parsing."""
with pytest.raises(ValueError, match="Model name must be"):
_parse_model_string("just-a-model-name")
with pytest.raises(ValueError, match="Invalid model format "):
_parse_model_string("")
with pytest.raises(ValueError, match="is not supported"):
_parse_model_string(":model-name")
with pytest.raises(ValueError, match="Model name cannot be empty"):
_parse_model_string("openai:")
with pytest.raises(
ValueError,
match="Provider 'invalid-provider' is not supported",
):
_parse_model_string("invalid-provider:model-name")
for provider in _SUPPORTED_PROVIDERS:
with pytest.raises(ValueError, match=f"{provider}"):
_parse_model_string("invalid-provider:model-name")
def test_infer_model_and_provider() -> None:
"""Test model and provider inference from different input formats."""
assert _infer_model_and_provider("openai:text-embedding-3-small") == (
"openai",
"text-embedding-3-small",
)
assert _infer_model_and_provider(
model="text-embedding-3-small",
provider="openai",
) == ("openai", "text-embedding-3-small")
assert _infer_model_and_provider(
model="ft:text-embedding-3-small",
provider="openai",
) == ("openai", "ft:text-embedding-3-small")
assert _infer_model_and_provider(model="openai:ft:text-embedding-3-small") == (
"openai",
"ft:text-embedding-3-small",
)
def test_infer_model_and_provider_errors() -> None:
"""Test error cases for model and provider inference."""
# Test missing provider
with pytest.raises(ValueError, match="Must specify either"):
_infer_model_and_provider("text-embedding-3-small")
# Test empty model
with pytest.raises(ValueError, match="Model name cannot be empty"):
_infer_model_and_provider("")
# Test empty provider with model
with pytest.raises(ValueError, match="Must specify either"):
_infer_model_and_provider("model", provider="")
# Test invalid provider
with pytest.raises(ValueError, match="Provider 'invalid' is not supported.") as exc:
_infer_model_and_provider("model", provider="invalid")
# Test provider list is in error
for provider in _SUPPORTED_PROVIDERS:
assert provider in str(exc.value)
@pytest.mark.parametrize(
"provider",
sorted(_SUPPORTED_PROVIDERS.keys()),
)
def test_supported_providers_package_names(provider: str) -> None:
"""Test that all supported providers have valid package names."""
package = _SUPPORTED_PROVIDERS[provider]
assert "-" not in package
assert package.startswith("langchain_")
assert package.islower()
def test_is_sorted() -> None:
assert list(_SUPPORTED_PROVIDERS) == sorted(_SUPPORTED_PROVIDERS.keys())

View File

@@ -0,0 +1,250 @@
"""Embeddings tests."""
import contextlib
import hashlib
import importlib
import warnings
import pytest
from langchain_core.embeddings import Embeddings
from langchain.embeddings import CacheBackedEmbeddings
from langchain.storage.in_memory import InMemoryStore
class MockEmbeddings(Embeddings):
def embed_documents(self, texts: list[str]) -> list[list[float]]:
# Simulate embedding documents
embeddings: list[list[float]] = []
for text in texts:
if text == "RAISE_EXCEPTION":
msg = "Simulated embedding failure"
raise ValueError(msg)
embeddings.append([len(text), len(text) + 1])
return embeddings
def embed_query(self, text: str) -> list[float]:
# Simulate embedding a query
return [5.0, 6.0]
@pytest.fixture
def cache_embeddings() -> CacheBackedEmbeddings:
"""Create a cache backed embeddings."""
store = InMemoryStore()
embeddings = MockEmbeddings()
return CacheBackedEmbeddings.from_bytes_store(
embeddings,
store,
namespace="test_namespace",
)
@pytest.fixture
def cache_embeddings_batch() -> CacheBackedEmbeddings:
"""Create a cache backed embeddings with a batch_size of 3."""
store = InMemoryStore()
embeddings = MockEmbeddings()
return CacheBackedEmbeddings.from_bytes_store(
embeddings,
store,
namespace="test_namespace",
batch_size=3,
)
@pytest.fixture
def cache_embeddings_with_query() -> CacheBackedEmbeddings:
"""Create a cache backed embeddings with query caching."""
doc_store = InMemoryStore()
query_store = InMemoryStore()
embeddings = MockEmbeddings()
return CacheBackedEmbeddings.from_bytes_store(
embeddings,
document_embedding_cache=doc_store,
namespace="test_namespace",
query_embedding_cache=query_store,
)
def test_embed_documents(cache_embeddings: CacheBackedEmbeddings) -> None:
texts = ["1", "22", "a", "333"]
vectors = cache_embeddings.embed_documents(texts)
expected_vectors: list[list[float]] = [[1, 2.0], [2.0, 3.0], [1.0, 2.0], [3.0, 4.0]]
assert vectors == expected_vectors
keys = list(cache_embeddings.document_embedding_store.yield_keys())
assert len(keys) == 4
# UUID is expected to be the same for the same text
assert keys[0] == "test_namespace812b86c1-8ebf-5483-95c6-c95cf2b52d12"
def test_embed_documents_batch(cache_embeddings_batch: CacheBackedEmbeddings) -> None:
# "RAISE_EXCEPTION" forces a failure in batch 2
texts = ["1", "22", "a", "333", "RAISE_EXCEPTION"]
with contextlib.suppress(ValueError):
cache_embeddings_batch.embed_documents(texts)
keys = list(cache_embeddings_batch.document_embedding_store.yield_keys())
# only the first batch of three embeddings should exist
assert len(keys) == 3
# UUID is expected to be the same for the same text
assert keys[0] == "test_namespace812b86c1-8ebf-5483-95c6-c95cf2b52d12"
def test_embed_query(cache_embeddings: CacheBackedEmbeddings) -> None:
text = "query_text"
vector = cache_embeddings.embed_query(text)
expected_vector = [5.0, 6.0]
assert vector == expected_vector
assert cache_embeddings.query_embedding_store is None
def test_embed_cached_query(cache_embeddings_with_query: CacheBackedEmbeddings) -> None:
text = "query_text"
vector = cache_embeddings_with_query.embed_query(text)
expected_vector = [5.0, 6.0]
assert vector == expected_vector
keys = list(cache_embeddings_with_query.query_embedding_store.yield_keys()) # type: ignore[union-attr]
assert len(keys) == 1
assert keys[0] == "test_namespace89ec3dae-a4d9-5636-a62e-ff3b56cdfa15"
async def test_aembed_documents(cache_embeddings: CacheBackedEmbeddings) -> None:
texts = ["1", "22", "a", "333"]
vectors = await cache_embeddings.aembed_documents(texts)
expected_vectors: list[list[float]] = [[1, 2.0], [2.0, 3.0], [1.0, 2.0], [3.0, 4.0]]
assert vectors == expected_vectors
keys = [
key async for key in cache_embeddings.document_embedding_store.ayield_keys()
]
assert len(keys) == 4
# UUID is expected to be the same for the same text
assert keys[0] == "test_namespace812b86c1-8ebf-5483-95c6-c95cf2b52d12"
async def test_aembed_documents_batch(
cache_embeddings_batch: CacheBackedEmbeddings,
) -> None:
# "RAISE_EXCEPTION" forces a failure in batch 2
texts = ["1", "22", "a", "333", "RAISE_EXCEPTION"]
with contextlib.suppress(ValueError):
await cache_embeddings_batch.aembed_documents(texts)
keys = [
key
async for key in cache_embeddings_batch.document_embedding_store.ayield_keys()
]
# only the first batch of three embeddings should exist
assert len(keys) == 3
# UUID is expected to be the same for the same text
assert keys[0] == "test_namespace812b86c1-8ebf-5483-95c6-c95cf2b52d12"
async def test_aembed_query(cache_embeddings: CacheBackedEmbeddings) -> None:
text = "query_text"
vector = await cache_embeddings.aembed_query(text)
expected_vector = [5.0, 6.0]
assert vector == expected_vector
async def test_aembed_query_cached(
cache_embeddings_with_query: CacheBackedEmbeddings,
) -> None:
text = "query_text"
await cache_embeddings_with_query.aembed_query(text)
keys = list(cache_embeddings_with_query.query_embedding_store.yield_keys()) # type: ignore[union-attr]
assert len(keys) == 1
assert keys[0] == "test_namespace89ec3dae-a4d9-5636-a62e-ff3b56cdfa15"
def test_blake2b_encoder() -> None:
"""Test that the blake2b encoder is used to encode keys in the cache store."""
store = InMemoryStore()
emb = MockEmbeddings()
cbe = CacheBackedEmbeddings.from_bytes_store(
emb,
store,
namespace="ns_",
key_encoder="blake2b",
)
text = "blake"
cbe.embed_documents([text])
# rebuild the key exactly as the library does
expected_key = "ns_" + hashlib.blake2b(text.encode()).hexdigest()
assert list(cbe.document_embedding_store.yield_keys()) == [expected_key]
def test_sha256_encoder() -> None:
"""Test that the sha256 encoder is used to encode keys in the cache store."""
store = InMemoryStore()
emb = MockEmbeddings()
cbe = CacheBackedEmbeddings.from_bytes_store(
emb,
store,
namespace="ns_",
key_encoder="sha256",
)
text = "foo"
cbe.embed_documents([text])
# rebuild the key exactly as the library does
expected_key = "ns_" + hashlib.sha256(text.encode()).hexdigest()
assert list(cbe.document_embedding_store.yield_keys()) == [expected_key]
def test_sha512_encoder() -> None:
"""Test that the sha512 encoder is used to encode keys in the cache store."""
store = InMemoryStore()
emb = MockEmbeddings()
cbe = CacheBackedEmbeddings.from_bytes_store(
emb,
store,
namespace="ns_",
key_encoder="sha512",
)
text = "foo"
cbe.embed_documents([text])
# rebuild the key exactly as the library does
expected_key = "ns_" + hashlib.sha512(text.encode()).hexdigest()
assert list(cbe.document_embedding_store.yield_keys()) == [expected_key]
def test_sha1_warning_emitted_once() -> None:
"""Test that a warning is emitted when using SHA-1 as the default key encoder."""
module = importlib.import_module(CacheBackedEmbeddings.__module__)
# Create a *temporary* MonkeyPatch object whose effects disappear
# automatically when the with-block exits.
with pytest.MonkeyPatch.context() as mp:
# We're monkey patching the module to reset the `_warned_about_sha1` flag
# which may have been set while testing other parts of the codebase.
mp.setattr(module, "_warned_about_sha1", False, raising=False)
store = InMemoryStore()
emb = MockEmbeddings()
with warnings.catch_warnings(record=True) as caught:
warnings.simplefilter("always")
CacheBackedEmbeddings.from_bytes_store(emb, store) # triggers warning
CacheBackedEmbeddings.from_bytes_store(emb, store) # silent
sha1_msgs = [w for w in caught if "SHA-1" in str(w.message)]
assert len(sha1_msgs) == 1
def test_custom_encoder() -> None:
"""Test that a custom encoder can be used to encode keys in the cache store."""
store = InMemoryStore()
emb = MockEmbeddings()
def custom_upper(text: str) -> str: # very simple demo encoder
return "CUSTOM_" + text.upper()
cbe = CacheBackedEmbeddings.from_bytes_store(emb, store, key_encoder=custom_upper)
txt = "x"
cbe.embed_documents([txt])
assert list(cbe.document_embedding_store.yield_keys()) == ["CUSTOM_X"]

View File

@@ -0,0 +1,10 @@
from langchain import embeddings
EXPECTED_ALL = [
"CacheBackedEmbeddings",
"init_embeddings",
]
def test_all_imports() -> None:
assert set(embeddings.__all__) == set(EXPECTED_ALL)

View File

@@ -0,0 +1 @@
"""Test prompt functionality."""

View File

@@ -0,0 +1,24 @@
from langchain import prompts
EXPECTED_ALL = [
"AIMessagePromptTemplate",
"BaseChatPromptTemplate",
"BasePromptTemplate",
"ChatMessagePromptTemplate",
"ChatPromptTemplate",
"FewShotPromptTemplate",
"FewShotPromptWithTemplates",
"HumanMessagePromptTemplate",
"LengthBasedExampleSelector",
"MaxMarginalRelevanceExampleSelector",
"MessagesPlaceholder",
"PromptTemplate",
"SemanticSimilarityExampleSelector",
"StringPromptTemplate",
"SystemMessagePromptTemplate",
"FewShotChatMessagePromptTemplate",
]
def test_all_imports() -> None:
assert set(prompts.__all__) == set(EXPECTED_ALL)

View File

@@ -0,0 +1,155 @@
import tempfile
from collections.abc import Generator
from pathlib import Path
import pytest
from langchain_core.stores import InvalidKeyException
from langchain.storage.file_system import LocalFileStore
@pytest.fixture
def file_store() -> Generator[LocalFileStore, None, None]:
# Create a temporary directory for testing
with tempfile.TemporaryDirectory() as temp_dir:
# Instantiate the LocalFileStore with the temporary directory as the root path
store = LocalFileStore(temp_dir)
yield store
def test_mset_and_mget(file_store: LocalFileStore) -> None:
# Set values for keys
key_value_pairs = [("key1", b"value1"), ("key2", b"value2")]
file_store.mset(key_value_pairs)
# Get values for keys
values = file_store.mget(["key1", "key2"])
# Assert that the retrieved values match the original values
assert values == [b"value1", b"value2"]
@pytest.mark.parametrize(
("chmod_dir_s", "chmod_file_s"),
[("777", "666"), ("770", "660"), ("700", "600")],
)
def test_mset_chmod(chmod_dir_s: str, chmod_file_s: str) -> None:
chmod_dir = int(chmod_dir_s, base=8)
chmod_file = int(chmod_file_s, base=8)
# Create a temporary directory for testing
with tempfile.TemporaryDirectory() as temp_dir:
# Instantiate the LocalFileStore with a directory inside the temporary directory
# as the root path
file_store = LocalFileStore(
Path(temp_dir) / "store_dir",
chmod_dir=chmod_dir,
chmod_file=chmod_file,
)
# Set values for keys
key_value_pairs = [("key1", b"value1"), ("key2", b"value2")]
file_store.mset(key_value_pairs)
# verify the permissions are set correctly
# (test only the standard user/group/other bits)
dir_path = file_store.root_path
file_path = file_store.root_path / "key1"
assert (dir_path.stat().st_mode & 0o777) == chmod_dir
assert (file_path.stat().st_mode & 0o777) == chmod_file
def test_mget_update_atime() -> None:
# Create a temporary directory for testing
with tempfile.TemporaryDirectory() as temp_dir:
# Instantiate the LocalFileStore with a directory inside the temporary directory
# as the root path
file_store = LocalFileStore(Path(temp_dir) / "store_dir", update_atime=True)
# Set values for keys
key_value_pairs = [("key1", b"value1"), ("key2", b"value2")]
file_store.mset(key_value_pairs)
# Get original access time
file_path = file_store.root_path / "key1"
atime1 = file_path.stat().st_atime
# Get values for keys
_ = file_store.mget(["key1", "key2"])
# Make sure the filesystem access time has been updated
atime2 = file_path.stat().st_atime
assert atime2 != atime1
def test_mdelete(file_store: LocalFileStore) -> None:
# Set values for keys
key_value_pairs = [("key1", b"value1"), ("key2", b"value2")]
file_store.mset(key_value_pairs)
# Delete keys
file_store.mdelete(["key1"])
# Check if the deleted key is present
values = file_store.mget(["key1"])
# Assert that the value is None after deletion
assert values == [None]
def test_set_invalid_key(file_store: LocalFileStore) -> None:
"""Test that an exception is raised when an invalid key is set."""
# Set a key-value pair
key = "crying-cat/😿"
value = b"This is a test value"
with pytest.raises(InvalidKeyException):
file_store.mset([(key, value)])
def test_set_key_and_verify_content(file_store: LocalFileStore) -> None:
"""Test that the content of the file is the same as the value set."""
# Set a key-value pair
key = "test_key"
value = b"This is a test value"
file_store.mset([(key, value)])
# Verify the content of the actual file
full_path = file_store._get_full_path(key)
assert full_path.exists()
assert full_path.read_bytes() == b"This is a test value"
def test_yield_keys(file_store: LocalFileStore) -> None:
# Set values for keys
key_value_pairs = [("key1", b"value1"), ("subdir/key2", b"value2")]
file_store.mset(key_value_pairs)
# Iterate over keys
keys = list(file_store.yield_keys())
# Assert that the yielded keys match the expected keys
expected_keys = ["key1", str(Path("subdir") / "key2")]
assert keys == expected_keys
def test_catches_forbidden_keys(file_store: LocalFileStore) -> None:
"""Make sure we raise exception on keys that are not allowed; e.g., absolute path"""
with pytest.raises(InvalidKeyException):
file_store.mset([("/etc", b"value1")])
with pytest.raises(InvalidKeyException):
list(file_store.yield_keys("/etc/passwd"))
with pytest.raises(InvalidKeyException):
file_store.mget(["/etc/passwd"])
# check relative paths
with pytest.raises(InvalidKeyException):
list(file_store.yield_keys(".."))
with pytest.raises(InvalidKeyException):
file_store.mget(["../etc/passwd"])
with pytest.raises(InvalidKeyException):
file_store.mset([("../etc", b"value1")])
with pytest.raises(InvalidKeyException):
list(file_store.yield_keys("../etc/passwd"))

View File

@@ -0,0 +1,15 @@
from langchain import storage
EXPECTED_ALL = [
"EncoderBackedStore",
"InMemoryStore",
"InMemoryByteStore",
"LocalFileStore",
"InvalidKeyException",
"create_lc_store",
"create_kv_docstore",
]
def test_all_imports() -> None:
assert set(storage.__all__) == set(EXPECTED_ALL)

View File

@@ -0,0 +1,37 @@
import tempfile
from collections.abc import Generator
from typing import cast
import pytest
from langchain_core.documents import Document
from langchain.storage._lc_store import create_kv_docstore, create_lc_store
from langchain.storage.file_system import LocalFileStore
@pytest.fixture
def file_store() -> Generator[LocalFileStore, None, None]:
# Create a temporary directory for testing
with tempfile.TemporaryDirectory() as temp_dir:
# Instantiate the LocalFileStore with the temporary directory as the root path
store = LocalFileStore(temp_dir)
yield store
def test_create_lc_store(file_store: LocalFileStore) -> None:
"""Test that a docstore is created from a base store."""
docstore = create_lc_store(file_store)
docstore.mset([("key1", Document(page_content="hello", metadata={"key": "value"}))])
fetched_doc = cast("Document", docstore.mget(["key1"])[0])
assert fetched_doc.page_content == "hello"
assert fetched_doc.metadata == {"key": "value"}
def test_create_kv_store(file_store: LocalFileStore) -> None:
"""Test that a docstore is created from a base store."""
docstore = create_kv_docstore(file_store)
docstore.mset([("key1", Document(page_content="hello", metadata={"key": "value"}))])
fetched_doc = docstore.mget(["key1"])[0]
assert isinstance(fetched_doc, Document)
assert fetched_doc.page_content == "hello"
assert fetched_doc.metadata == {"key": "value"}
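# A usage sketch for the two helpers exercised above (the directory argument
# is a hypothetical example). Both wrap a plain byte store so Document
# objects round-trip without manual serialization; the cast in
# test_create_lc_store reflects that create_kv_docstore is the
# Document-typed variant.
def _example_usage(root_dir: str) -> None:
    doc_store = create_kv_docstore(LocalFileStore(root_dir))
    doc_store.mset([("doc-1", Document(page_content="hello"))])
    restored = doc_store.mget(["doc-1"])[0]
    assert restored is not None
    assert restored.page_content == "hello"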

View File

@@ -0,0 +1,46 @@
from typing import Any
from langchain_core.documents import Document
from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage
class AnyStr(str):
__slots__ = ()
def __eq__(self, other: object) -> bool:
return isinstance(other, str)
# The code below creates versions of the pydantic models that will work in
# unit tests with AnyStr as the id field.
# Note that the `id` field is assigned AFTER the model is created to work
# around an issue with pydantic ignoring the __eq__ method on subclassed
# strings.
def _AnyIdDocument(**kwargs: Any) -> Document:
"""Create a document with an id field."""
message = Document(**kwargs)
message.id = AnyStr()
return message
def _AnyIdAIMessage(**kwargs: Any) -> AIMessage:
"""Create ai message with an any id field."""
message = AIMessage(**kwargs)
message.id = AnyStr()
return message
def _AnyIdAIMessageChunk(**kwargs: Any) -> AIMessageChunk:
"""Create ai message with an any id field."""
message = AIMessageChunk(**kwargs)
message.id = AnyStr()
return message
def _AnyIdHumanMessage(**kwargs: Any) -> HumanMessage:
"""Create a human with an any id field."""
message = HumanMessage(**kwargs)
message.id = AnyStr()
return message
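# A usage sketch of why these factories exist: AnyStr compares equal to any
# string, so expected messages built here match real messages regardless of
# their runtime-generated ids ("run-123" below is a hypothetical id).
_expected = _AnyIdAIMessage(content="hello")
_actual = AIMessage(content="hello", id="run-123")
assert _expected == _actual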

View File

@@ -0,0 +1,35 @@
"""A unit test meant to catch accidental introduction of non-optional dependencies."""
from collections.abc import Mapping
from pathlib import Path
from typing import Any
import pytest
import toml
from packaging.requirements import Requirement
HERE = Path(__file__).parent
PYPROJECT_TOML = HERE / "../../pyproject.toml"
@pytest.fixture
def uv_conf() -> dict[str, Any]:
"""Load the pyproject.toml file."""
with PYPROJECT_TOML.open() as f:
return toml.load(f)
def test_required_dependencies(uv_conf: Mapping[str, Any]) -> None:
"""A test that checks if a new non-optional dependency is being introduced.
    If this test fails, it means that a contributor is trying to introduce a new
    required dependency. This should be avoided in most situations.
"""
    # Get the dependencies from the [project.dependencies] section of pyproject.toml
dependencies = uv_conf["project"]["dependencies"]
required_dependencies = {Requirement(dep).name for dep in dependencies}
assert sorted(required_dependencies) == sorted(
["langchain-core", "langchain-text-splitters", "langgraph", "pydantic"]
)
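# A small illustration of the name extraction above: Requirement parses a
# PEP 508 dependency string and .name drops any version specifiers (the
# constraint below is a hypothetical example).
assert Requirement("langchain-core>=0.3,<2.0").name == "langchain-core"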

View File

@@ -0,0 +1,60 @@
import importlib
import warnings
from pathlib import Path
# Attempt to recursively import all modules in langchain
PKG_ROOT = Path(__file__).parent.parent.parent
def test_import_all() -> None:
"""Generate the public API for this package."""
with warnings.catch_warnings():
warnings.filterwarnings(action="ignore", category=UserWarning)
library_code = PKG_ROOT / "langchain"
for path in library_code.rglob("*.py"):
# Calculate the relative path to the module
module_name = (
path.relative_to(PKG_ROOT).with_suffix("").as_posix().replace("/", ".")
)
if module_name.endswith("__init__"):
                # Drop the trailing __init__ to get the package name
module_name = module_name.rsplit(".", 1)[0]
mod = importlib.import_module(module_name)
all_attrs = getattr(mod, "__all__", [])
for name in all_attrs:
# Attempt to import the name from the module
try:
obj = getattr(mod, name)
assert obj is not None
except Exception as e:
msg = f"Could not import {module_name}.{name}"
raise AssertionError(msg) from e
def test_import_all_using_dir() -> None:
"""Generate the public API for this package."""
library_code = PKG_ROOT / "langchain"
for path in library_code.rglob("*.py"):
# Calculate the relative path to the module
module_name = (
path.relative_to(PKG_ROOT).with_suffix("").as_posix().replace("/", ".")
)
if module_name.endswith("__init__"):
            # Drop the trailing __init__ to get the package name
module_name = module_name.rsplit(".", 1)[0]
try:
mod = importlib.import_module(module_name)
except ModuleNotFoundError as e:
msg = f"Could not import {module_name}"
raise ModuleNotFoundError(msg) from e
attributes = dir(mod)
for name in attributes:
if name.strip().startswith("_"):
continue
# Attempt to import the name from the module
getattr(mod, name)
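# A worked example of the module-name computation used by both tests above
# (the checkout path is hypothetical; no filesystem access is needed):
_example_root = Path("/checkout/libs/langchain_v1")
_example_file = _example_root / "langchain" / "storage" / "__init__.py"
_example_name = (
    _example_file.relative_to(_example_root)
    .with_suffix("")
    .as_posix()
    .replace("/", ".")
)
assert _example_name == "langchain.storage.__init__"
assert _example_name.rsplit(".", 1)[0] == "langchain.storage"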

View File

@@ -0,0 +1,11 @@
import pytest
import pytest_socket
import requests
def test_socket_disabled() -> None:
"""This test should fail."""
with pytest.raises(pytest_socket.SocketBlockedError):
# Ignore S113 since we don't need a timeout here as the request
# should fail immediately
requests.get("https://www.example.com") # noqa: S113

View File

@@ -0,0 +1,14 @@
from langchain import tools
EXPECTED_ALL = {
"BaseTool",
"InjectedToolArg",
"InjectedToolCallId",
"Tool",
"ToolException",
"tool",
}
def test_all_imports() -> None:
assert set(tools.__all__) == EXPECTED_ALL

4562
libs/langchain_v1/uv.lock generated Normal file

File diff suppressed because it is too large