Mirror of https://github.com/hwchase17/langchain.git (synced 2025-09-06 05:25:04 +00:00)
feat(langchain): v1 scaffolding (#32166)
This PR adds scaffolding for the langchain 1.0 entry package. Most contents have been removed.

Currently remaining entrypoints:

* chat models
* embedding models
* memory -> trimming messages, filtering messages, and counting tokens [we may remove this]
* prompts -> we may remove some prompts
* storage -> primarily to support cache-backed embeddings; we may remove the KV store
* tools -> re-export tool primitives

Things to be added:

* Selected agent implementations
* Selected workflows
* Common primitives: messages, Document
* Primitives for type hinting: BaseChatModel, BaseEmbeddings
* Selected retrievers
* Selected text splitters

Things to be removed:

* Globals need to be removed (requires an update in langchain-core)

Todos:

* TBD: indexing API (requires SQLAlchemy, which we don't want as a dependency)
* Be explicit about public/private interfaces (e.g., likely rename chat_models.base.py to something more internal)
* Remove Dockerfiles
* Update module docstrings and README.md
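For orientation, a minimal sketch of the entrypoints that remain in the scaffolded package (the imports mirror the modules added in this diff; the model names are placeholders and require the matching integration packages):

```python
from langchain.chat_models import init_chat_model
from langchain.embeddings import init_embeddings

chat = init_chat_model("openai:gpt-4o-mini", temperature=0)
embedder = init_embeddings("openai:text-embedding-3-small")
```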
.github/scripts/check_diff.py (vendored), 1 line added

@@ -16,6 +16,7 @@ LANGCHAIN_DIRS = [
     "libs/core",
     "libs/text-splitters",
     "libs/langchain",
+    "libs/langchain_v1",
 ]
 
 # when set to True, we are ignoring core dependents
libs/langchain_v1/.dockerignore (new file, 6 lines)

.venv
.github
.git
.mypy_cache
.pytest_cache
Dockerfile
libs/langchain_v1/LICENSE (new file, 21 lines)

MIT License

Copyright (c) LangChain, Inc.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
libs/langchain_v1/Makefile (new file, 101 lines)

.PHONY: all clean docs_build docs_clean docs_linkcheck api_docs_build api_docs_clean api_docs_linkcheck format lint test tests test_watch integration_tests docker_tests help extended_tests

# Default target executed when no arguments are given to make.
all: help

######################
# TESTING AND COVERAGE
######################

# Define a variable for the test file path.
TEST_FILE ?= tests/unit_tests/

.EXPORT_ALL_VARIABLES:
UV_FROZEN = true

# Run unit tests and generate a coverage report.
coverage:
	uv run --group test pytest --cov \
		--cov-config=.coveragerc \
		--cov-report xml \
		--cov-report term-missing:skip-covered \
		$(TEST_FILE)

test tests:
	uv run --group test pytest -n auto --disable-socket --allow-unix-socket $(TEST_FILE)

extended_tests:
	uv run --group test pytest --disable-socket --allow-unix-socket --only-extended tests/unit_tests

test_watch:
	uv run --group test ptw --snapshot-update --now . -- -x --disable-socket --allow-unix-socket --disable-warnings tests/unit_tests

test_watch_extended:
	uv run --group test ptw --snapshot-update --now . -- -x --disable-socket --allow-unix-socket --only-extended tests/unit_tests

integration_tests:
	uv run --group test --group test_integration pytest tests/integration_tests

docker_tests:
	docker build -t my-langchain-image:test .
	docker run --rm my-langchain-image:test

check_imports: $(shell find langchain -name '*.py')
	uv run python ./scripts/check_imports.py $^

######################
# LINTING AND FORMATTING
######################

# Define a variable for Python and notebook files.
PYTHON_FILES=.
MYPY_CACHE=.mypy_cache
lint format: PYTHON_FILES=.
lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/langchain --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
lint_package: PYTHON_FILES=langchain
lint_tests: PYTHON_FILES=tests
lint_tests: MYPY_CACHE=.mypy_cache_test

lint lint_diff lint_package lint_tests:
	[ "$(PYTHON_FILES)" = "" ] || uv run --all-groups ruff check $(PYTHON_FILES)
	[ "$(PYTHON_FILES)" = "" ] || uv run --all-groups ruff format $(PYTHON_FILES) --diff
	[ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) && uv run --all-groups mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)

format format_diff:
	[ "$(PYTHON_FILES)" = "" ] || uv run --all-groups ruff format $(PYTHON_FILES)
	[ "$(PYTHON_FILES)" = "" ] || uv run --all-groups ruff check --fix $(PYTHON_FILES)

spell_check:
	uv run --all-groups codespell --toml pyproject.toml

spell_fix:
	uv run --all-groups codespell --toml pyproject.toml -w

######################
# HELP
######################

help:
	@echo '===================='
	@echo 'clean - run docs_clean and api_docs_clean'
	@echo 'docs_build - build the documentation'
	@echo 'docs_clean - clean the documentation build artifacts'
	@echo 'docs_linkcheck - run linkchecker on the documentation'
	@echo 'api_docs_build - build the API Reference documentation'
	@echo 'api_docs_clean - clean the API Reference documentation build artifacts'
	@echo 'api_docs_linkcheck - run linkchecker on the API Reference documentation'
	@echo '-- LINTING --'
	@echo 'format - run code formatters'
	@echo 'lint - run linters'
	@echo 'spell_check - run codespell on the project'
	@echo 'spell_fix - run codespell on the project and fix the errors'
	@echo '-- TESTS --'
	@echo 'coverage - run unit tests and generate coverage report'
	@echo 'test - run unit tests'
	@echo 'tests - run unit tests (alias for "make test")'
	@echo 'test TEST_FILE=<test_file> - run all tests in file'
	@echo 'extended_tests - run only extended unit tests'
	@echo 'test_watch - run unit tests in watch mode'
	@echo 'integration_tests - run integration tests'
	@echo 'docker_tests - run unit tests in docker'
	@echo '-- DOCUMENTATION tasks are from the top-level Makefile --'
libs/langchain_v1/README.md (new file, 91 lines)

# 🦜️🔗 LangChain

⚡ Building applications with LLMs through composability ⚡

[Release Notes](https://github.com/langchain-ai/langchain/releases)
[Lint status](https://github.com/langchain-ai/langchain/actions/workflows/lint.yml)
[Test status](https://github.com/langchain-ai/langchain/actions/workflows/test.yml)
[Downloads](https://pepy.tech/project/langchain)
[License: MIT](https://opensource.org/licenses/MIT)
[Twitter](https://twitter.com/langchainai)
[Open in Dev Containers](https://vscode.dev/redirect?url=vscode://ms-vscode-remote.remote-containers/cloneInVolume?url=https://github.com/langchain-ai/langchain)
[Open in GitHub Codespaces](https://codespaces.new/langchain-ai/langchain)
[Star History](https://star-history.com/#langchain-ai/langchain)
[Dependents](https://libraries.io/github/langchain-ai/langchain)
[Open Issues](https://github.com/langchain-ai/langchain/issues)

Looking for the JS/TS version? Check out [LangChain.js](https://github.com/langchain-ai/langchainjs).

To help you ship LangChain apps to production faster, check out [LangSmith](https://smith.langchain.com).
[LangSmith](https://smith.langchain.com) is a unified developer platform for building, testing, and monitoring LLM applications.
Fill out [this form](https://www.langchain.com/contact-sales) to speak with our sales team.

## Quick Install

`pip install langchain`
or
`pip install langsmith && conda install langchain -c conda-forge`
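A minimal usage sketch after installing (assumes the `langchain-openai` integration package and an `OPENAI_API_KEY` in the environment; the model name is a placeholder):

```python
from langchain.chat_models import init_chat_model

model = init_chat_model("openai:gpt-4o-mini")
print(model.invoke("Say hello in one short sentence.").content)
```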
## 🤔 What is this?

Large language models (LLMs) are emerging as a transformative technology, enabling developers to build applications that they previously could not. However, using these LLMs in isolation is often insufficient for creating a truly powerful app - the real power comes when you can combine them with other sources of computation or knowledge.

This library aims to assist in the development of those types of applications. Common examples of these applications include:

**❓ Question answering with RAG**

- [Documentation](https://python.langchain.com/docs/use_cases/question_answering/)
- End-to-end Example: [Chat LangChain](https://chat.langchain.com) and [repo](https://github.com/langchain-ai/chat-langchain)

**🧱 Extracting structured output**

- [Documentation](https://python.langchain.com/docs/use_cases/extraction/)
- End-to-end Example: [SQL Llama2 Template](https://github.com/langchain-ai/langchain-extract/)

**🤖 Chatbots**

- [Documentation](https://python.langchain.com/docs/use_cases/chatbots)
- End-to-end Example: [Web LangChain (web researcher chatbot)](https://weblangchain.vercel.app) and [repo](https://github.com/langchain-ai/weblangchain)

## 📖 Documentation

Please see [here](https://python.langchain.com) for full documentation on:

- Getting started (installation, setting up the environment, simple examples)
- How-To examples (demos, integrations, helper functions)
- Reference (full API docs)
- Resources (high-level explanation of core concepts)

## 🚀 What can this help with?

There are five main areas that LangChain is designed to help with.
These are, in increasing order of complexity:

**📃 Models and Prompts:**

This includes prompt management, prompt optimization, a generic interface for all LLMs, and common utilities for working with chat models and LLMs.

**🔗 Chains:**

Chains go beyond a single LLM call and involve sequences of calls (whether to an LLM or a different utility). LangChain provides a standard interface for chains, lots of integrations with other tools, and end-to-end chains for common applications.

**📚 Retrieval Augmented Generation:**

Retrieval Augmented Generation involves specific types of chains that first interact with an external data source to fetch data for use in the generation step. Examples include summarization of long pieces of text and question/answering over specific data sources.

**🤖 Agents:**

Agents involve an LLM making decisions about which Actions to take, taking that Action, seeing an Observation, and repeating that until done. LangChain provides a standard interface for agents, a selection of agents to choose from, and examples of end-to-end agents.

**🧐 Evaluation:**

[BETA] Generative models are notoriously hard to evaluate with traditional metrics. One new way of evaluating them is using language models themselves to do the evaluation. LangChain provides some prompts/chains for assisting in this.

For more information on these concepts, please see our [full documentation](https://python.langchain.com).

## 💁 Contributing

As an open-source project in a rapidly developing field, we are extremely open to contributions, whether it be in the form of a new feature, improved infrastructure, or better documentation.

For detailed information on how to contribute, see the [Contributing Guide](https://python.langchain.com/docs/contributing/).
libs/langchain_v1/extended_testing_deps.txt (new file, 5 lines)

-e ../partners/openai
-e ../partners/anthropic
-e ../partners/fireworks
-e ../partners/mistralai
-e ../partners/groq
libs/langchain_v1/langchain/__init__.py (new file, 29 lines)

"""Main entrypoint into package."""

from importlib import metadata
from typing import Any

try:
    __version__ = metadata.version(__package__)
except metadata.PackageNotFoundError:
    # Case where package metadata is not available.
    __version__ = ""
del metadata  # optional, avoids polluting the results of dir(__package__)


def __getattr__(name: str) -> Any:  # noqa: ANN401
    """Get an attribute from the package."""
    if name == "verbose":
        from langchain.globals import _verbose

        return _verbose
    if name == "debug":
        from langchain.globals import _debug

        return _debug
    if name == "llm_cache":
        from langchain.globals import _llm_cache

        return _llm_cache
    msg = f"Could not find: {name}"
    raise AttributeError(msg)
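A brief illustration of the lazy-attribute hook above (assumes `langchain.globals` is importable, which the PR description notes may still be removed):

```python
import langchain

# Neither name lives in the module's __dict__; both are resolved by __getattr__,
# which imports langchain.globals on first access.
print(langchain.verbose)
print(langchain.debug)
```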
libs/langchain_v1/langchain/chat_models/__init__.py (new file, 24 lines)

"""**Chat Models** are a variation on language models.

While Chat Models use language models under the hood, the interface they expose
is a bit different. Rather than expose a "text in, text out" API, they expose
an interface where "chat messages" are the inputs and outputs.

**Class hierarchy:**

.. code-block::

    BaseLanguageModel --> BaseChatModel --> <name>  # Examples: ChatOpenAI, ChatGooglePalm

**Main helpers:**

.. code-block::

    AIMessage, BaseMessage, HumanMessage
"""  # noqa: E501

from langchain.chat_models.base import init_chat_model

__all__ = [
    "init_chat_model",
]
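A hedged illustration of the "chat messages in, chat messages out" interface described in the docstring above (assumes `langchain-anthropic` is installed; the model name is a placeholder):

```python
from langchain.chat_models import init_chat_model
from langchain_core.messages import HumanMessage, SystemMessage

model = init_chat_model("anthropic:claude-3-5-sonnet-latest")
reply = model.invoke(
    [SystemMessage("You are terse."), HumanMessage("Name one prime number.")]
)
# `reply` is an AIMessage, i.e. a chat message rather than raw text.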
libs/langchain_v1/langchain/chat_models/base.py (new file, 946 lines)
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from importlib import util
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Literal,
|
||||
Optional,
|
||||
Union,
|
||||
cast,
|
||||
overload,
|
||||
)
|
||||
|
||||
from langchain_core.language_models import (
|
||||
BaseChatModel,
|
||||
LanguageModelInput,
|
||||
)
|
||||
from langchain_core.messages import AnyMessage, BaseMessage
|
||||
from langchain_core.runnables import Runnable, RunnableConfig, ensure_config
|
||||
from typing_extensions import TypeAlias, override
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncIterator, Iterator, Sequence
|
||||
|
||||
from langchain_core.runnables.schema import StreamEvent
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.tracers import RunLog, RunLogPatch
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@overload
|
||||
def init_chat_model(
|
||||
model: str,
|
||||
*,
|
||||
model_provider: Optional[str] = None,
|
||||
configurable_fields: Literal[None] = None,
|
||||
config_prefix: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> BaseChatModel: ...
|
||||
|
||||
|
||||
@overload
|
||||
def init_chat_model(
|
||||
model: Literal[None] = None,
|
||||
*,
|
||||
model_provider: Optional[str] = None,
|
||||
configurable_fields: Literal[None] = None,
|
||||
config_prefix: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> _ConfigurableModel: ...
|
||||
|
||||
|
||||
@overload
|
||||
def init_chat_model(
|
||||
model: Optional[str] = None,
|
||||
*,
|
||||
model_provider: Optional[str] = None,
|
||||
configurable_fields: Union[Literal["any"], list[str], tuple[str, ...]] = ...,
|
||||
config_prefix: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> _ConfigurableModel: ...
|
||||
|
||||
|
||||
# FOR CONTRIBUTORS: If adding support for a new provider, please append the provider
|
||||
# name to the supported list in the docstring below. Do *not* change the order of the
|
||||
# existing providers.
|
||||
def init_chat_model(
|
||||
model: Optional[str] = None,
|
||||
*,
|
||||
model_provider: Optional[str] = None,
|
||||
configurable_fields: Optional[
|
||||
Union[Literal["any"], list[str], tuple[str, ...]]
|
||||
] = None,
|
||||
config_prefix: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[BaseChatModel, _ConfigurableModel]:
|
||||
"""Initialize a ChatModel from the model name and provider.
|
||||
|
||||
**Note:** Must have the integration package corresponding to the model provider
|
||||
installed.
|
||||
|
||||
Args:
|
||||
model: The name of the model, e.g. "o3-mini", "claude-3-5-sonnet-latest". You can
|
||||
also specify model and model provider in a single argument using
|
||||
'{model_provider}:{model}' format, e.g. "openai:o1".
|
||||
model_provider: The model provider if not specified as part of model arg (see
|
||||
above). Supported model_provider values and the corresponding integration
|
||||
package are:
|
||||
|
||||
- 'openai' -> langchain-openai
|
||||
- 'anthropic' -> langchain-anthropic
|
||||
- 'azure_openai' -> langchain-openai
|
||||
- 'azure_ai' -> langchain-azure-ai
|
||||
- 'google_vertexai' -> langchain-google-vertexai
|
||||
- 'google_genai' -> langchain-google-genai
|
||||
- 'bedrock' -> langchain-aws
|
||||
- 'bedrock_converse' -> langchain-aws
|
||||
- 'cohere' -> langchain-cohere
|
||||
- 'fireworks' -> langchain-fireworks
|
||||
- 'together' -> langchain-together
|
||||
- 'mistralai' -> langchain-mistralai
|
||||
- 'huggingface' -> langchain-huggingface
|
||||
- 'groq' -> langchain-groq
|
||||
- 'ollama' -> langchain-ollama
|
||||
- 'google_anthropic_vertex' -> langchain-google-vertexai
|
||||
- 'deepseek' -> langchain-deepseek
|
||||
- 'ibm' -> langchain-ibm
|
||||
- 'nvidia' -> langchain-nvidia-ai-endpoints
|
||||
- 'xai' -> langchain-xai
|
||||
- 'perplexity' -> langchain-perplexity
|
||||
|
||||
Will attempt to infer model_provider from model if not specified. The
|
||||
following providers will be inferred based on these model prefixes:
|
||||
|
||||
- 'gpt-3...' | 'gpt-4...' | 'o1...' -> 'openai'
|
||||
- 'claude...' -> 'anthropic'
|
||||
- 'amazon....' -> 'bedrock'
|
||||
- 'gemini...' -> 'google_vertexai'
|
||||
- 'command...' -> 'cohere'
|
||||
- 'accounts/fireworks...' -> 'fireworks'
|
||||
- 'mistral...' -> 'mistralai'
|
||||
- 'deepseek...' -> 'deepseek'
|
||||
- 'grok...' -> 'xai'
|
||||
- 'sonar...' -> 'perplexity'
|
||||
configurable_fields: Which model parameters are
|
||||
configurable:
|
||||
|
||||
- None: No configurable fields.
|
||||
- "any": All fields are configurable. *See Security Note below.*
|
||||
- Union[List[str], Tuple[str, ...]]: Specified fields are configurable.
|
||||
|
||||
Fields are assumed to have config_prefix stripped if there is a
|
||||
config_prefix. If model is specified, then defaults to None. If model is
|
||||
not specified, then defaults to ``("model", "model_provider")``.
|
||||
|
||||
***Security Note***: Setting ``configurable_fields="any"`` means fields like
|
||||
api_key, base_url, etc. can be altered at runtime, potentially redirecting
|
||||
model requests to a different service/user. Make sure that if you're
|
||||
accepting untrusted configurations, you enumerate the
|
||||
``configurable_fields=(...)`` explicitly.
|
||||
|
||||
config_prefix: If config_prefix is a non-empty string then model will be
|
||||
configurable at runtime via the
|
||||
``config["configurable"]["{config_prefix}_{param}"]`` keys. If
|
||||
config_prefix is an empty string then model will be configurable via
|
||||
``config["configurable"]["{param}"]``.
|
||||
temperature: Model temperature.
|
||||
max_tokens: Max output tokens.
|
||||
timeout: The maximum time (in seconds) to wait for a response from the model
|
||||
before canceling the request.
|
||||
max_retries: The maximum number of attempts the system will make to resend a
|
||||
request if it fails due to issues like network timeouts or rate limits.
|
||||
base_url: The URL of the API endpoint where requests are sent.
|
||||
rate_limiter: A ``BaseRateLimiter`` to space out requests to avoid exceeding
|
||||
rate limits.
|
||||
kwargs: Additional model-specific keyword args to pass to
|
||||
``<<selected ChatModel>>.__init__(model=model_name, **kwargs)``.
|
||||
|
||||
Returns:
|
||||
A BaseChatModel corresponding to the model_name and model_provider specified if
|
||||
configurability is inferred to be False. If configurable, a chat model emulator
|
||||
that initializes the underlying model at runtime once a config is passed in.
|
||||
|
||||
Raises:
|
||||
ValueError: If model_provider cannot be inferred or isn't supported.
|
||||
ImportError: If the model provider integration package is not installed.
|
||||
|
||||
.. dropdown:: Init non-configurable model
|
||||
:open:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# pip install langchain langchain-openai langchain-anthropic langchain-google-vertexai
|
||||
from langchain.chat_models import init_chat_model
|
||||
|
||||
o3_mini = init_chat_model("openai:o3-mini", temperature=0)
|
||||
claude_sonnet = init_chat_model("anthropic:claude-3-5-sonnet-latest", temperature=0)
|
||||
gemini_2_flash = init_chat_model("google_vertexai:gemini-2.0-flash", temperature=0)
|
||||
|
||||
o3_mini.invoke("what's your name")
|
||||
claude_sonnet.invoke("what's your name")
|
||||
gemini_2_flash.invoke("what's your name")
|
||||
|
||||
|
||||
.. dropdown:: Partially configurable model with no default
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# pip install langchain langchain-openai langchain-anthropic
|
||||
from langchain.chat_models import init_chat_model
|
||||
|
||||
# We don't need to specify configurable=True if a model isn't specified.
|
||||
configurable_model = init_chat_model(temperature=0)
|
||||
|
||||
configurable_model.invoke(
|
||||
"what's your name",
|
||||
config={"configurable": {"model": "gpt-4o"}}
|
||||
)
|
||||
# GPT-4o response
|
||||
|
||||
configurable_model.invoke(
|
||||
"what's your name",
|
||||
config={"configurable": {"model": "claude-3-5-sonnet-latest"}}
|
||||
)
|
||||
# claude-3.5 sonnet response
|
||||
|
||||
.. dropdown:: Fully configurable model with a default
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# pip install langchain langchain-openai langchain-anthropic
|
||||
from langchain.chat_models import init_chat_model
|
||||
|
||||
configurable_model_with_default = init_chat_model(
|
||||
"openai:gpt-4o",
|
||||
configurable_fields="any", # this allows us to configure other params like temperature, max_tokens, etc at runtime.
|
||||
config_prefix="foo",
|
||||
temperature=0
|
||||
)
|
||||
|
||||
configurable_model_with_default.invoke("what's your name")
|
||||
# GPT-4o response with temperature 0
|
||||
|
||||
configurable_model_with_default.invoke(
|
||||
"what's your name",
|
||||
config={
|
||||
"configurable": {
|
||||
"foo_model": "anthropic:claude-3-5-sonnet-20240620",
|
||||
"foo_temperature": 0.6
|
||||
}
|
||||
}
|
||||
)
|
||||
# Claude-3.5 sonnet response with temperature 0.6
|
||||
|
||||
.. dropdown:: Bind tools to a configurable model
|
||||
|
||||
You can call any ChatModel declarative methods on a configurable model in the
|
||||
same way that you would with a normal model.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# pip install langchain langchain-openai langchain-anthropic
|
||||
from langchain.chat_models import init_chat_model
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
class GetWeather(BaseModel):
|
||||
'''Get the current weather in a given location'''
|
||||
|
||||
location: str = Field(..., description="The city and state, e.g. San Francisco, CA")
|
||||
|
||||
class GetPopulation(BaseModel):
|
||||
'''Get the current population in a given location'''
|
||||
|
||||
location: str = Field(..., description="The city and state, e.g. San Francisco, CA")
|
||||
|
||||
configurable_model = init_chat_model(
|
||||
"gpt-4o",
|
||||
configurable_fields=("model", "model_provider"),
|
||||
temperature=0
|
||||
)
|
||||
|
||||
configurable_model_with_tools = configurable_model.bind_tools([GetWeather, GetPopulation])
|
||||
configurable_model_with_tools.invoke(
|
||||
"Which city is hotter today and which is bigger: LA or NY?"
|
||||
)
|
||||
# GPT-4o response with tool calls
|
||||
|
||||
configurable_model_with_tools.invoke(
|
||||
"Which city is hotter today and which is bigger: LA or NY?",
|
||||
config={"configurable": {"model": "claude-3-5-sonnet-20240620"}}
|
||||
)
|
||||
# Claude-3.5 sonnet response with tools
|
||||
|
||||
.. versionadded:: 0.2.7
|
||||
|
||||
.. versionchanged:: 0.2.8
|
||||
|
||||
Support for ``configurable_fields`` and ``config_prefix`` added.
|
||||
|
||||
.. versionchanged:: 0.2.12
|
||||
|
||||
Support for Ollama via langchain-ollama package added
|
||||
(langchain_ollama.ChatOllama). Previously,
|
||||
the now-deprecated langchain-community version of Ollama was imported
|
||||
(langchain_community.chat_models.ChatOllama).
|
||||
|
||||
Support for AWS Bedrock models via the Converse API added
|
||||
(model_provider="bedrock_converse").
|
||||
|
||||
.. versionchanged:: 0.3.5
|
||||
|
||||
Out of beta.
|
||||
|
||||
.. versionchanged:: 0.3.19
|
||||
|
||||
Support for Deepseek, IBM, Nvidia, and xAI models added.
|
||||
|
||||
""" # noqa: E501
|
||||
if not model and not configurable_fields:
|
||||
configurable_fields = ("model", "model_provider")
|
||||
config_prefix = config_prefix or ""
|
||||
if config_prefix and not configurable_fields:
|
||||
warnings.warn(
|
||||
f"{config_prefix=} has been set but no fields are configurable. Set "
|
||||
f"`configurable_fields=(...)` to specify the model params that are "
|
||||
f"configurable.",
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
if not configurable_fields:
|
||||
return _init_chat_model_helper(
|
||||
cast("str", model),
|
||||
model_provider=model_provider,
|
||||
**kwargs,
|
||||
)
|
||||
if model:
|
||||
kwargs["model"] = model
|
||||
if model_provider:
|
||||
kwargs["model_provider"] = model_provider
|
||||
return _ConfigurableModel(
|
||||
default_config=kwargs,
|
||||
config_prefix=config_prefix,
|
||||
configurable_fields=configurable_fields,
|
||||
)
|
||||
|
||||
|
||||
def _init_chat_model_helper(
|
||||
model: str,
|
||||
*,
|
||||
model_provider: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> BaseChatModel:
|
||||
model, model_provider = _parse_model(model, model_provider)
|
||||
if model_provider == "openai":
|
||||
_check_pkg("langchain_openai")
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
return ChatOpenAI(model=model, **kwargs)
|
||||
if model_provider == "anthropic":
|
||||
_check_pkg("langchain_anthropic")
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
|
||||
return ChatAnthropic(model=model, **kwargs) # type: ignore[call-arg,unused-ignore]
|
||||
if model_provider == "azure_openai":
|
||||
_check_pkg("langchain_openai")
|
||||
from langchain_openai import AzureChatOpenAI
|
||||
|
||||
return AzureChatOpenAI(model=model, **kwargs)
|
||||
if model_provider == "azure_ai":
|
||||
_check_pkg("langchain_azure_ai")
|
||||
from langchain_azure_ai.chat_models import AzureAIChatCompletionsModel
|
||||
|
||||
return AzureAIChatCompletionsModel(model=model, **kwargs)
|
||||
if model_provider == "cohere":
|
||||
_check_pkg("langchain_cohere")
|
||||
from langchain_cohere import ChatCohere
|
||||
|
||||
return ChatCohere(model=model, **kwargs)
|
||||
if model_provider == "google_vertexai":
|
||||
_check_pkg("langchain_google_vertexai")
|
||||
from langchain_google_vertexai import ChatVertexAI
|
||||
|
||||
return ChatVertexAI(model=model, **kwargs)
|
||||
if model_provider == "google_genai":
|
||||
_check_pkg("langchain_google_genai")
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
|
||||
return ChatGoogleGenerativeAI(model=model, **kwargs)
|
||||
if model_provider == "fireworks":
|
||||
_check_pkg("langchain_fireworks")
|
||||
from langchain_fireworks import ChatFireworks
|
||||
|
||||
return ChatFireworks(model=model, **kwargs)
|
||||
if model_provider == "ollama":
|
||||
try:
|
||||
_check_pkg("langchain_ollama")
|
||||
from langchain_ollama import ChatOllama
|
||||
except ImportError:
|
||||
# For backwards compatibility
|
||||
try:
|
||||
_check_pkg("langchain_community")
|
||||
from langchain_community.chat_models import ChatOllama
|
||||
except ImportError:
|
||||
# If both langchain-ollama and langchain-community aren't available,
|
||||
# raise an error related to langchain-ollama
|
||||
_check_pkg("langchain_ollama")
|
||||
|
||||
return ChatOllama(model=model, **kwargs)
|
||||
if model_provider == "together":
|
||||
_check_pkg("langchain_together")
|
||||
from langchain_together import ChatTogether
|
||||
|
||||
return ChatTogether(model=model, **kwargs)
|
||||
if model_provider == "mistralai":
|
||||
_check_pkg("langchain_mistralai")
|
||||
from langchain_mistralai import ChatMistralAI
|
||||
|
||||
return ChatMistralAI(model=model, **kwargs) # type: ignore[call-arg,unused-ignore]
|
||||
if model_provider == "huggingface":
|
||||
_check_pkg("langchain_huggingface")
|
||||
from langchain_huggingface import ChatHuggingFace
|
||||
|
||||
return ChatHuggingFace(model_id=model, **kwargs)
|
||||
if model_provider == "groq":
|
||||
_check_pkg("langchain_groq")
|
||||
from langchain_groq import ChatGroq
|
||||
|
||||
return ChatGroq(model=model, **kwargs)
|
||||
if model_provider == "bedrock":
|
||||
_check_pkg("langchain_aws")
|
||||
from langchain_aws import ChatBedrock
|
||||
|
||||
return ChatBedrock(model_id=model, **kwargs)
|
||||
if model_provider == "bedrock_converse":
|
||||
_check_pkg("langchain_aws")
|
||||
from langchain_aws import ChatBedrockConverse
|
||||
|
||||
return ChatBedrockConverse(model=model, **kwargs)
|
||||
if model_provider == "google_anthropic_vertex":
|
||||
_check_pkg("langchain_google_vertexai")
|
||||
from langchain_google_vertexai.model_garden import ChatAnthropicVertex
|
||||
|
||||
return ChatAnthropicVertex(model=model, **kwargs)
|
||||
if model_provider == "deepseek":
|
||||
_check_pkg("langchain_deepseek", pkg_kebab="langchain-deepseek")
|
||||
from langchain_deepseek import ChatDeepSeek
|
||||
|
||||
return ChatDeepSeek(model=model, **kwargs)
|
||||
if model_provider == "nvidia":
|
||||
_check_pkg("langchain_nvidia_ai_endpoints")
|
||||
from langchain_nvidia_ai_endpoints import ChatNVIDIA
|
||||
|
||||
return ChatNVIDIA(model=model, **kwargs)
|
||||
if model_provider == "ibm":
|
||||
_check_pkg("langchain_ibm")
|
||||
from langchain_ibm import ChatWatsonx
|
||||
|
||||
return ChatWatsonx(model_id=model, **kwargs)
|
||||
if model_provider == "xai":
|
||||
_check_pkg("langchain_xai")
|
||||
from langchain_xai import ChatXAI
|
||||
|
||||
return ChatXAI(model=model, **kwargs)
|
||||
if model_provider == "perplexity":
|
||||
_check_pkg("langchain_perplexity")
|
||||
from langchain_perplexity import ChatPerplexity
|
||||
|
||||
return ChatPerplexity(model=model, **kwargs)
|
||||
supported = ", ".join(_SUPPORTED_PROVIDERS)
|
||||
msg = (
|
||||
f"Unsupported {model_provider=}.\n\nSupported model providers are: {supported}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
_SUPPORTED_PROVIDERS = {
|
||||
"openai",
|
||||
"anthropic",
|
||||
"azure_openai",
|
||||
"azure_ai",
|
||||
"cohere",
|
||||
"google_vertexai",
|
||||
"google_genai",
|
||||
"fireworks",
|
||||
"ollama",
|
||||
"together",
|
||||
"mistralai",
|
||||
"huggingface",
|
||||
"groq",
|
||||
"bedrock",
|
||||
"bedrock_converse",
|
||||
"google_anthropic_vertex",
|
||||
"deepseek",
|
||||
"ibm",
|
||||
"xai",
|
||||
"perplexity",
|
||||
}
|
||||
|
||||
|
||||
def _attempt_infer_model_provider(model_name: str) -> Optional[str]:
|
||||
if any(model_name.startswith(pre) for pre in ("gpt-3", "gpt-4", "o1", "o3")):
|
||||
return "openai"
|
||||
if model_name.startswith("claude"):
|
||||
return "anthropic"
|
||||
if model_name.startswith("command"):
|
||||
return "cohere"
|
||||
if model_name.startswith("accounts/fireworks"):
|
||||
return "fireworks"
|
||||
if model_name.startswith("gemini"):
|
||||
return "google_vertexai"
|
||||
if model_name.startswith("amazon."):
|
||||
return "bedrock"
|
||||
if model_name.startswith("mistral"):
|
||||
return "mistralai"
|
||||
if model_name.startswith("deepseek"):
|
||||
return "deepseek"
|
||||
if model_name.startswith("grok"):
|
||||
return "xai"
|
||||
if model_name.startswith("sonar"):
|
||||
return "perplexity"
|
||||
return None
|
||||
|
||||
|
||||
def _parse_model(model: str, model_provider: Optional[str]) -> tuple[str, str]:
|
||||
if (
|
||||
not model_provider
|
||||
and ":" in model
|
||||
and model.split(":")[0] in _SUPPORTED_PROVIDERS
|
||||
):
|
||||
model_provider = model.split(":")[0]
|
||||
model = ":".join(model.split(":")[1:])
|
||||
model_provider = model_provider or _attempt_infer_model_provider(model)
|
||||
if not model_provider:
|
||||
msg = (
|
||||
f"Unable to infer model provider for {model=}, please specify "
|
||||
f"model_provider directly."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
model_provider = model_provider.replace("-", "_").lower()
|
||||
return model, model_provider
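# Illustrative only (not part of the diff): expected behavior of the helpers above,
# given the provider prefixes and _SUPPORTED_PROVIDERS entries shown.
#   _parse_model("openai:gpt-4o", None)            -> ("gpt-4o", "openai")
#   _parse_model("claude-3-5-sonnet-latest", None) -> ("claude-3-5-sonnet-latest", "anthropic")
#   _parse_model("my-custom-model", None)          -> raises ValueError (provider cannot be inferred)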
|
||||
|
||||
|
||||
def _check_pkg(pkg: str, *, pkg_kebab: Optional[str] = None) -> None:
|
||||
if not util.find_spec(pkg):
|
||||
pkg_kebab = pkg_kebab if pkg_kebab is not None else pkg.replace("_", "-")
|
||||
msg = (
|
||||
f"Unable to import {pkg}. Please install with `pip install -U {pkg_kebab}`"
|
||||
)
|
||||
raise ImportError(msg)
|
||||
|
||||
|
||||
def _remove_prefix(s: str, prefix: str) -> str:
|
||||
return s.removeprefix(prefix)
|
||||
|
||||
|
||||
_DECLARATIVE_METHODS = ("bind_tools", "with_structured_output")
|
||||
|
||||
|
||||
class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
default_config: Optional[dict] = None,
|
||||
configurable_fields: Union[Literal["any"], list[str], tuple[str, ...]] = "any",
|
||||
config_prefix: str = "",
|
||||
queued_declarative_operations: Sequence[tuple[str, tuple, dict]] = (),
|
||||
) -> None:
|
||||
self._default_config: dict = default_config or {}
|
||||
self._configurable_fields: Union[Literal["any"], list[str]] = (
|
||||
configurable_fields
|
||||
if configurable_fields == "any"
|
||||
else list(configurable_fields)
|
||||
)
|
||||
self._config_prefix = (
|
||||
config_prefix + "_"
|
||||
if config_prefix and not config_prefix.endswith("_")
|
||||
else config_prefix
|
||||
)
|
||||
self._queued_declarative_operations: list[tuple[str, tuple, dict]] = list(
|
||||
queued_declarative_operations,
|
||||
)
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
if name in _DECLARATIVE_METHODS:
|
||||
# Declarative operations that cannot be applied until after an actual model
|
||||
# object is instantiated. So instead of returning the actual operation,
|
||||
# we record the operation and its arguments in a queue. This queue is
|
||||
# then applied in order whenever we actually instantiate the model (in
|
||||
# self._model()).
|
||||
def queue(*args: Any, **kwargs: Any) -> _ConfigurableModel:
|
||||
queued_declarative_operations = list(
|
||||
self._queued_declarative_operations,
|
||||
)
|
||||
queued_declarative_operations.append((name, args, kwargs))
|
||||
return _ConfigurableModel(
|
||||
default_config=dict(self._default_config),
|
||||
configurable_fields=list(self._configurable_fields)
|
||||
if isinstance(self._configurable_fields, list)
|
||||
else self._configurable_fields,
|
||||
config_prefix=self._config_prefix,
|
||||
queued_declarative_operations=queued_declarative_operations,
|
||||
)
|
||||
|
||||
return queue
|
||||
if self._default_config and (model := self._model()) and hasattr(model, name):
|
||||
return getattr(model, name)
|
||||
msg = f"{name} is not a BaseChatModel attribute"
|
||||
if self._default_config:
|
||||
msg += " and is not implemented on the default model"
|
||||
msg += "."
|
||||
raise AttributeError(msg)
|
||||
|
||||
def _model(self, config: Optional[RunnableConfig] = None) -> Runnable:
|
||||
params = {**self._default_config, **self._model_params(config)}
|
||||
model = _init_chat_model_helper(**params)
|
||||
for name, args, kwargs in self._queued_declarative_operations:
|
||||
model = getattr(model, name)(*args, **kwargs)
|
||||
return model
|
||||
|
||||
def _model_params(self, config: Optional[RunnableConfig]) -> dict:
|
||||
config = ensure_config(config)
|
||||
model_params = {
|
||||
_remove_prefix(k, self._config_prefix): v
|
||||
for k, v in config.get("configurable", {}).items()
|
||||
if k.startswith(self._config_prefix)
|
||||
}
|
||||
if self._configurable_fields != "any":
|
||||
model_params = {
|
||||
k: v for k, v in model_params.items() if k in self._configurable_fields
|
||||
}
|
||||
return model_params
|
||||
|
||||
def with_config(
|
||||
self,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> _ConfigurableModel:
|
||||
"""Bind config to a Runnable, returning a new Runnable."""
|
||||
config = RunnableConfig(**(config or {}), **cast("RunnableConfig", kwargs))
|
||||
model_params = self._model_params(config)
|
||||
remaining_config = {k: v for k, v in config.items() if k != "configurable"}
|
||||
remaining_config["configurable"] = {
|
||||
k: v
|
||||
for k, v in config.get("configurable", {}).items()
|
||||
if _remove_prefix(k, self._config_prefix) not in model_params
|
||||
}
|
||||
queued_declarative_operations = list(self._queued_declarative_operations)
|
||||
if remaining_config:
|
||||
queued_declarative_operations.append(
|
||||
(
|
||||
"with_config",
|
||||
(),
|
||||
{"config": remaining_config},
|
||||
),
|
||||
)
|
||||
return _ConfigurableModel(
|
||||
default_config={**self._default_config, **model_params},
|
||||
configurable_fields=list(self._configurable_fields)
|
||||
if isinstance(self._configurable_fields, list)
|
||||
else self._configurable_fields,
|
||||
config_prefix=self._config_prefix,
|
||||
queued_declarative_operations=queued_declarative_operations,
|
||||
)
|
||||
|
||||
@property
|
||||
def InputType(self) -> TypeAlias:
|
||||
"""Get the input type for this runnable."""
|
||||
from langchain_core.prompt_values import (
|
||||
ChatPromptValueConcrete,
|
||||
StringPromptValue,
|
||||
)
|
||||
|
||||
# This is a version of LanguageModelInput which replaces the abstract
|
||||
# base class BaseMessage with a union of its subclasses, which makes
|
||||
# for a much better schema.
|
||||
return Union[
|
||||
str,
|
||||
Union[StringPromptValue, ChatPromptValueConcrete],
|
||||
list[AnyMessage],
|
||||
]
|
||||
|
||||
@override
|
||||
def invoke(
|
||||
self,
|
||||
input: LanguageModelInput,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
return self._model(config).invoke(input, config=config, **kwargs)
|
||||
|
||||
@override
|
||||
async def ainvoke(
|
||||
self,
|
||||
input: LanguageModelInput,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
return await self._model(config).ainvoke(input, config=config, **kwargs)
|
||||
|
||||
@override
|
||||
def stream(
|
||||
self,
|
||||
input: LanguageModelInput,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> Iterator[Any]:
|
||||
yield from self._model(config).stream(input, config=config, **kwargs)
|
||||
|
||||
@override
|
||||
async def astream(
|
||||
self,
|
||||
input: LanguageModelInput,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> AsyncIterator[Any]:
|
||||
async for x in self._model(config).astream(input, config=config, **kwargs):
|
||||
yield x
|
||||
|
||||
def batch(
|
||||
self,
|
||||
inputs: list[LanguageModelInput],
|
||||
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Optional[Any],
|
||||
) -> list[Any]:
|
||||
config = config or None
|
||||
# If <= 1 config, use the underlying model's batch implementation.
|
||||
if config is None or isinstance(config, dict) or len(config) <= 1:
|
||||
if isinstance(config, list):
|
||||
config = config[0]
|
||||
return self._model(config).batch(
|
||||
inputs,
|
||||
config=config,
|
||||
return_exceptions=return_exceptions,
|
||||
**kwargs,
|
||||
)
|
||||
# If multiple configs, default to Runnable.batch, which uses an executor to invoke
|
||||
# in parallel.
|
||||
return super().batch(
|
||||
inputs,
|
||||
config=config,
|
||||
return_exceptions=return_exceptions,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
async def abatch(
|
||||
self,
|
||||
inputs: list[LanguageModelInput],
|
||||
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Optional[Any],
|
||||
) -> list[Any]:
|
||||
config = config or None
|
||||
# If <= 1 config, use the underlying model's batch implementation.
|
||||
if config is None or isinstance(config, dict) or len(config) <= 1:
|
||||
if isinstance(config, list):
|
||||
config = config[0]
|
||||
return await self._model(config).abatch(
|
||||
inputs,
|
||||
config=config,
|
||||
return_exceptions=return_exceptions,
|
||||
**kwargs,
|
||||
)
|
||||
# If multiple configs, default to Runnable.batch, which uses an executor to invoke
|
||||
# in parallel.
|
||||
return await super().abatch(
|
||||
inputs,
|
||||
config=config,
|
||||
return_exceptions=return_exceptions,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def batch_as_completed(
|
||||
self,
|
||||
inputs: Sequence[LanguageModelInput],
|
||||
config: Optional[Union[RunnableConfig, Sequence[RunnableConfig]]] = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[tuple[int, Union[Any, Exception]]]:
|
||||
config = config or None
|
||||
# If <= 1 config, use the underlying model's batch implementation.
|
||||
if config is None or isinstance(config, dict) or len(config) <= 1:
|
||||
if isinstance(config, list):
|
||||
config = config[0]
|
||||
yield from self._model(cast("RunnableConfig", config)).batch_as_completed( # type: ignore[call-overload]
|
||||
inputs,
|
||||
config=config,
|
||||
return_exceptions=return_exceptions,
|
||||
**kwargs,
|
||||
)
|
||||
# If multiple configs, default to Runnable.batch, which uses an executor to invoke
|
||||
# in parallel.
|
||||
else:
|
||||
yield from super().batch_as_completed( # type: ignore[call-overload]
|
||||
inputs,
|
||||
config=config,
|
||||
return_exceptions=return_exceptions,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
async def abatch_as_completed(
|
||||
self,
|
||||
inputs: Sequence[LanguageModelInput],
|
||||
config: Optional[Union[RunnableConfig, Sequence[RunnableConfig]]] = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[tuple[int, Any]]:
|
||||
config = config or None
|
||||
# If <= 1 config, use the underlying model's batch implementation.
|
||||
if config is None or isinstance(config, dict) or len(config) <= 1:
|
||||
if isinstance(config, list):
|
||||
config = config[0]
|
||||
async for x in self._model(
|
||||
cast("RunnableConfig", config),
|
||||
).abatch_as_completed( # type: ignore[call-overload]
|
||||
inputs,
|
||||
config=config,
|
||||
return_exceptions=return_exceptions,
|
||||
**kwargs,
|
||||
):
|
||||
yield x
|
||||
# If multiple configs, default to Runnable.batch, which uses an executor to invoke
|
||||
# in parallel.
|
||||
else:
|
||||
async for x in super().abatch_as_completed( # type: ignore[call-overload]
|
||||
inputs,
|
||||
config=config,
|
||||
return_exceptions=return_exceptions,
|
||||
**kwargs,
|
||||
):
|
||||
yield x
|
||||
|
||||
@override
|
||||
def transform(
|
||||
self,
|
||||
input: Iterator[LanguageModelInput],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> Iterator[Any]:
|
||||
yield from self._model(config).transform(input, config=config, **kwargs)
|
||||
|
||||
@override
|
||||
async def atransform(
|
||||
self,
|
||||
input: AsyncIterator[LanguageModelInput],
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> AsyncIterator[Any]:
|
||||
async for x in self._model(config).atransform(input, config=config, **kwargs):
|
||||
yield x
|
||||
|
||||
@overload
|
||||
def astream_log(
|
||||
self,
|
||||
input: Any,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
*,
|
||||
diff: Literal[True] = True,
|
||||
with_streamed_output_list: bool = True,
|
||||
include_names: Optional[Sequence[str]] = None,
|
||||
include_types: Optional[Sequence[str]] = None,
|
||||
include_tags: Optional[Sequence[str]] = None,
|
||||
exclude_names: Optional[Sequence[str]] = None,
|
||||
exclude_types: Optional[Sequence[str]] = None,
|
||||
exclude_tags: Optional[Sequence[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[RunLogPatch]: ...
|
||||
|
||||
@overload
|
||||
def astream_log(
|
||||
self,
|
||||
input: Any,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
*,
|
||||
diff: Literal[False],
|
||||
with_streamed_output_list: bool = True,
|
||||
include_names: Optional[Sequence[str]] = None,
|
||||
include_types: Optional[Sequence[str]] = None,
|
||||
include_tags: Optional[Sequence[str]] = None,
|
||||
exclude_names: Optional[Sequence[str]] = None,
|
||||
exclude_types: Optional[Sequence[str]] = None,
|
||||
exclude_tags: Optional[Sequence[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[RunLog]: ...
|
||||
|
||||
@override
|
||||
async def astream_log(
|
||||
self,
|
||||
input: Any,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
*,
|
||||
diff: bool = True,
|
||||
with_streamed_output_list: bool = True,
|
||||
include_names: Optional[Sequence[str]] = None,
|
||||
include_types: Optional[Sequence[str]] = None,
|
||||
include_tags: Optional[Sequence[str]] = None,
|
||||
exclude_names: Optional[Sequence[str]] = None,
|
||||
exclude_types: Optional[Sequence[str]] = None,
|
||||
exclude_tags: Optional[Sequence[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[AsyncIterator[RunLogPatch], AsyncIterator[RunLog]]:
|
||||
async for x in self._model(config).astream_log( # type: ignore[call-overload, misc]
|
||||
input,
|
||||
config=config,
|
||||
diff=diff,
|
||||
with_streamed_output_list=with_streamed_output_list,
|
||||
include_names=include_names,
|
||||
include_types=include_types,
|
||||
include_tags=include_tags,
|
||||
exclude_tags=exclude_tags,
|
||||
exclude_types=exclude_types,
|
||||
exclude_names=exclude_names,
|
||||
**kwargs,
|
||||
):
|
||||
yield x
|
||||
|
||||
@override
|
||||
async def astream_events(
|
||||
self,
|
||||
input: Any,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
*,
|
||||
version: Literal["v1", "v2"] = "v2",
|
||||
include_names: Optional[Sequence[str]] = None,
|
||||
include_types: Optional[Sequence[str]] = None,
|
||||
include_tags: Optional[Sequence[str]] = None,
|
||||
exclude_names: Optional[Sequence[str]] = None,
|
||||
exclude_types: Optional[Sequence[str]] = None,
|
||||
exclude_tags: Optional[Sequence[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[StreamEvent]:
|
||||
async for x in self._model(config).astream_events(
|
||||
input,
|
||||
config=config,
|
||||
version=version,
|
||||
include_names=include_names,
|
||||
include_types=include_types,
|
||||
include_tags=include_tags,
|
||||
exclude_tags=exclude_tags,
|
||||
exclude_types=exclude_types,
|
||||
exclude_names=exclude_names,
|
||||
**kwargs,
|
||||
):
|
||||
yield x
|
||||
|
||||
# Explicitly added to satisfy downstream linters.
|
||||
def bind_tools(
|
||||
self,
|
||||
tools: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]],
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, BaseMessage]:
|
||||
return self.__getattr__("bind_tools")(tools, **kwargs)
|
||||
|
||||
# Explicitly added to satisfy downstream linters.
|
||||
def with_structured_output(
|
||||
self,
|
||||
schema: Union[dict, type[BaseModel]],
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]:
|
||||
return self.__getattr__("with_structured_output")(schema, **kwargs)
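The class above queues declarative calls such as `bind_tools` and `with_structured_output` and only applies them once a concrete model is built from the runtime config. A minimal sketch, assuming `langchain-openai` is installed; the `Answer` schema and model name are placeholders:

```python
from pydantic import BaseModel

from langchain.chat_models import init_chat_model


class Answer(BaseModel):
    value: int


configurable = init_chat_model(configurable_fields=("model", "model_provider"))
structured = configurable.with_structured_output(Answer)  # queued, not yet applied
result = structured.invoke(
    "What is 2 + 2?",
    config={"configurable": {"model": "openai:gpt-4o-mini"}},
)
```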
|
libs/langchain_v1/langchain/embeddings/__init__.py (new file, 7 lines)

from langchain.embeddings.base import init_embeddings
from langchain.embeddings.cache import CacheBackedEmbeddings

__all__ = [
    "CacheBackedEmbeddings",
    "init_embeddings",
]
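These two exports are designed to compose: `init_embeddings` builds the underlying embedder and `CacheBackedEmbeddings` (detailed in `cache.py` below) wraps it with a key-value cache. A minimal sketch, assuming `langchain-openai` is installed and using the in-memory byte store from `langchain_core`; swap in a persistent store in practice:

```python
from langchain_core.stores import InMemoryByteStore

from langchain.embeddings import CacheBackedEmbeddings, init_embeddings

underlying = init_embeddings("openai:text-embedding-3-small")
store = InMemoryByteStore()
cached = CacheBackedEmbeddings.from_bytes_store(
    underlying, store, namespace="text-embedding-3-small"
)
cached.embed_documents(["hello world"])  # computed and written to the store
cached.embed_documents(["hello world"])  # served from the cache
```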
libs/langchain_v1/langchain/embeddings/base.py (new file, 235 lines)
import functools
|
||||
from importlib import util
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.runnables import Runnable
|
||||
|
||||
_SUPPORTED_PROVIDERS = {
|
||||
"azure_openai": "langchain_openai",
|
||||
"bedrock": "langchain_aws",
|
||||
"cohere": "langchain_cohere",
|
||||
"google_vertexai": "langchain_google_vertexai",
|
||||
"huggingface": "langchain_huggingface",
|
||||
"mistralai": "langchain_mistralai",
|
||||
"ollama": "langchain_ollama",
|
||||
"openai": "langchain_openai",
|
||||
}
|
||||
|
||||
|
||||
def _get_provider_list() -> str:
|
||||
"""Get formatted list of providers and their packages."""
|
||||
return "\n".join(
|
||||
f" - {p}: {pkg.replace('_', '-')}" for p, pkg in _SUPPORTED_PROVIDERS.items()
|
||||
)
|
||||
|
||||
|
||||
def _parse_model_string(model_name: str) -> tuple[str, str]:
|
||||
"""Parse a model string into provider and model name components.
|
||||
|
||||
The model string should be in the format 'provider:model-name', where provider
|
||||
is one of the supported providers.
|
||||
|
||||
Args:
|
||||
model_name: A model string in the format 'provider:model-name'
|
||||
|
||||
Returns:
|
||||
A tuple of (provider, model_name)
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
_parse_model_string("openai:text-embedding-3-small")
|
||||
# Returns: ("openai", "text-embedding-3-small")
|
||||
|
||||
_parse_model_string("bedrock:amazon.titan-embed-text-v1")
|
||||
# Returns: ("bedrock", "amazon.titan-embed-text-v1")
|
||||
|
||||
Raises:
|
||||
ValueError: If the model string is not in the correct format or
|
||||
the provider is unsupported
|
||||
"""
|
||||
if ":" not in model_name:
|
||||
providers = _SUPPORTED_PROVIDERS
|
||||
msg = (
|
||||
f"Invalid model format '{model_name}'.\n"
|
||||
f"Model name must be in format 'provider:model-name'\n"
|
||||
f"Example valid model strings:\n"
|
||||
f" - openai:text-embedding-3-small\n"
|
||||
f" - bedrock:amazon.titan-embed-text-v1\n"
|
||||
f" - cohere:embed-english-v3.0\n"
|
||||
f"Supported providers: {providers}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
provider, model = model_name.split(":", 1)
|
||||
provider = provider.lower().strip()
|
||||
model = model.strip()
|
||||
|
||||
if provider not in _SUPPORTED_PROVIDERS:
|
||||
msg = (
|
||||
f"Provider '{provider}' is not supported.\n"
|
||||
f"Supported providers and their required packages:\n"
|
||||
f"{_get_provider_list()}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
if not model:
|
||||
msg = "Model name cannot be empty"
|
||||
raise ValueError(msg)
|
||||
return provider, model
|
||||
|
||||
|
||||
def _infer_model_and_provider(
|
||||
model: str,
|
||||
*,
|
||||
provider: Optional[str] = None,
|
||||
) -> tuple[str, str]:
|
||||
if not model.strip():
|
||||
msg = "Model name cannot be empty"
|
||||
raise ValueError(msg)
|
||||
if provider is None and ":" in model:
|
||||
provider, model_name = _parse_model_string(model)
|
||||
else:
|
||||
model_name = model
|
||||
|
||||
if not provider:
|
||||
providers = _SUPPORTED_PROVIDERS
|
||||
msg = (
|
||||
"Must specify either:\n"
|
||||
"1. A model string in format 'provider:model-name'\n"
|
||||
" Example: 'openai:text-embedding-3-small'\n"
|
||||
"2. Or explicitly set provider from: "
|
||||
f"{providers}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
if provider not in _SUPPORTED_PROVIDERS:
|
||||
msg = (
|
||||
f"Provider '{provider}' is not supported.\n"
|
||||
f"Supported providers and their required packages:\n"
|
||||
f"{_get_provider_list()}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
return provider, model_name
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=len(_SUPPORTED_PROVIDERS))
|
||||
def _check_pkg(pkg: str) -> None:
|
||||
"""Check if a package is installed."""
|
||||
if not util.find_spec(pkg):
|
||||
msg = (
|
||||
f"Could not import {pkg} python package. "
|
||||
f"Please install it with `pip install {pkg}`"
|
||||
)
|
||||
raise ImportError(msg)
|
||||
|
||||
|
||||
def init_embeddings(
|
||||
model: str,
|
||||
*,
|
||||
provider: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[Embeddings, Runnable[Any, list[float]]]:
|
||||
"""Initialize an embeddings model from a model name and optional provider.
|
||||
|
||||
**Note:** Must have the integration package corresponding to the model provider
|
||||
installed.
|
||||
|
||||
Args:
|
||||
model: Name of the model to use. Can be either:
|
||||
- A model string like "openai:text-embedding-3-small"
|
||||
- Just the model name if provider is specified
|
||||
provider: Optional explicit provider name. If not specified,
|
||||
will attempt to parse from the model string. Supported providers
|
||||
and their required packages:
|
||||
|
||||
{_get_provider_list()}
|
||||
|
||||
**kwargs: Additional model-specific parameters passed to the embedding model.
|
||||
These vary by provider; see the provider-specific documentation for details.
|
||||
|
||||
Returns:
|
||||
An Embeddings instance that can generate embeddings for text.
|
||||
|
||||
Raises:
|
||||
ValueError: If the model provider is not supported or cannot be determined
|
||||
ImportError: If the required provider package is not installed
|
||||
|
||||
.. dropdown:: Example Usage
|
||||
:open:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# Using a model string
|
||||
model = init_embeddings("openai:text-embedding-3-small")
|
||||
model.embed_query("Hello, world!")
|
||||
|
||||
# Using explicit provider
|
||||
model = init_embeddings(
|
||||
model="text-embedding-3-small",
|
||||
provider="openai"
|
||||
)
|
||||
model.embed_documents(["Hello, world!", "Goodbye, world!"])
|
||||
|
||||
# With additional parameters
|
||||
model = init_embeddings(
|
||||
"openai:text-embedding-3-small",
|
||||
api_key="sk-..."
|
||||
)
|
||||
|
||||
.. versionadded:: 0.3.9
|
||||
"""
|
||||
if not model:
|
||||
providers = _SUPPORTED_PROVIDERS.keys()
|
||||
msg = (
|
||||
f"Must specify model name. Supported providers are: {', '.join(providers)}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
provider, model_name = _infer_model_and_provider(model, provider=provider)
|
||||
pkg = _SUPPORTED_PROVIDERS[provider]
|
||||
_check_pkg(pkg)
|
||||
|
||||
if provider == "openai":
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
|
||||
return OpenAIEmbeddings(model=model_name, **kwargs)
|
||||
if provider == "azure_openai":
|
||||
from langchain_openai import AzureOpenAIEmbeddings
|
||||
|
||||
return AzureOpenAIEmbeddings(model=model_name, **kwargs)
|
||||
if provider == "google_vertexai":
|
||||
from langchain_google_vertexai import VertexAIEmbeddings
|
||||
|
||||
return VertexAIEmbeddings(model=model_name, **kwargs)
|
||||
if provider == "bedrock":
|
||||
from langchain_aws import BedrockEmbeddings
|
||||
|
||||
return BedrockEmbeddings(model_id=model_name, **kwargs)
|
||||
if provider == "cohere":
|
||||
from langchain_cohere import CohereEmbeddings
|
||||
|
||||
return CohereEmbeddings(model=model_name, **kwargs)
|
||||
if provider == "mistralai":
|
||||
from langchain_mistralai import MistralAIEmbeddings
|
||||
|
||||
return MistralAIEmbeddings(model=model_name, **kwargs)
|
||||
if provider == "huggingface":
|
||||
from langchain_huggingface import HuggingFaceEmbeddings
|
||||
|
||||
return HuggingFaceEmbeddings(model_name=model_name, **kwargs)
|
||||
if provider == "ollama":
|
||||
from langchain_ollama import OllamaEmbeddings
|
||||
|
||||
return OllamaEmbeddings(model=model_name, **kwargs)
|
||||
msg = (
|
||||
f"Provider '{provider}' is not supported.\n"
|
||||
f"Supported providers and their required packages:\n"
|
||||
f"{_get_provider_list()}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Embeddings", # This one is for backwards compatibility
|
||||
"init_embeddings",
|
||||
]
|
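A minimal usage sketch for the helper above. It assumes only that the relevant provider extra (for example ``langchain-openai``) is installed; the wrapper function and its name are illustrative, not part of the package:

.. code-block:: python

    from langchain.embeddings.base import init_embeddings

    def load_embeddings(spec: str):
        """Resolve a "provider:model" string, surfacing a clear install hint on failure."""
        try:
            return init_embeddings(spec)
        except ImportError as exc:
            # e.g. "openai:text-embedding-3-small" needs `pip install langchain-openai`
            raise RuntimeError(f"Missing provider package for {spec!r}: {exc}") from exc

    embeddings = load_embeddings("openai:text-embedding-3-small")
    vector = embeddings.embed_query("Hello, world!")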
371
libs/langchain_v1/langchain/embeddings/cache.py
Normal file
@@ -0,0 +1,371 @@
|
||||
"""Module contains code for a cache backed embedder.
|
||||
|
||||
The cache backed embedder is a wrapper around an embedder that caches
|
||||
embeddings in a key-value store. The cache is used to avoid recomputing
|
||||
embeddings for the same text.
|
||||
|
||||
The text is hashed and the hash is used as the key in the cache.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import uuid
|
||||
import warnings
|
||||
from typing import TYPE_CHECKING, Callable, Literal, Optional, Union, cast
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.utils.iter import batch_iterate
|
||||
|
||||
from langchain.storage.encoder_backed import EncoderBackedStore
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
from langchain_core.stores import BaseStore, ByteStore
|
||||
|
||||
NAMESPACE_UUID = uuid.UUID(int=1985)
|
||||
|
||||
|
||||
def _sha1_hash_to_uuid(text: str) -> uuid.UUID:
|
||||
"""Return a UUID derived from *text* using SHA-1 (deterministic).
|
||||
|
||||
Deterministic and fast, **but not collision-resistant**.
|
||||
|
||||
A malicious attacker could try to create two different texts that hash to the same
|
||||
UUID. This may not necessarily be an issue in the context of caching embeddings,
|
||||
but new applications should swap this out for a stronger hash function like
|
||||
BLAKE2 or SHA-256, which are collision-resistant.
|
||||
"""
|
||||
sha1_hex = hashlib.sha1(text.encode("utf-8"), usedforsecurity=False).hexdigest()
|
||||
# Embed the hex string in `uuid5` to obtain a valid UUID.
|
||||
return uuid.uuid5(NAMESPACE_UUID, sha1_hex)
|
||||
|
||||
|
||||
def _make_default_key_encoder(namespace: str, algorithm: str) -> Callable[[str], str]:
|
||||
"""Create a default key encoder function.
|
||||
|
||||
Args:
|
||||
namespace: Prefix that segregates keys from different embedding models.
|
||||
algorithm:
|
||||
* ``'sha1'`` - fast but not collision-resistant
|
||||
* ``'blake2b'`` - cryptographically strong, faster than SHA-1
|
||||
* ``'sha256'`` - cryptographically strong, slower than SHA-1
|
||||
* ``'sha512'`` - cryptographically strong, slower than SHA-1
|
||||
|
||||
Returns:
|
||||
A function that encodes a key using the specified algorithm.
|
||||
"""
|
||||
if algorithm == "sha1":
|
||||
_warn_about_sha1_encoder()
|
||||
|
||||
def _key_encoder(key: str) -> str:
|
||||
"""Encode a key using the specified algorithm."""
|
||||
if algorithm == "sha1":
|
||||
return f"{namespace}{_sha1_hash_to_uuid(key)}"
|
||||
if algorithm == "blake2b":
|
||||
return f"{namespace}{hashlib.blake2b(key.encode('utf-8')).hexdigest()}"
|
||||
if algorithm == "sha256":
|
||||
return f"{namespace}{hashlib.sha256(key.encode('utf-8')).hexdigest()}"
|
||||
if algorithm == "sha512":
|
||||
return f"{namespace}{hashlib.sha512(key.encode('utf-8')).hexdigest()}"
|
||||
msg = f"Unsupported algorithm: {algorithm}"
|
||||
raise ValueError(msg)
|
||||
|
||||
return _key_encoder
|
||||
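Purely illustrative (not part of ``cache.py``): what the default key encoder produces for a given namespace and algorithm. The namespace string here is just an example:

.. code-block:: python

    # Keys are the namespace followed by a hex digest (or a UUID when using "sha1").
    encode = _make_default_key_encoder("text-embedding-3-small", "sha256")
    key = encode("hello world")
    assert key.startswith("text-embedding-3-small")
    assert len(key) == len("text-embedding-3-small") + 64  # sha256 hex digest is 64 chars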
|
||||
|
||||
def _value_serializer(value: Sequence[float]) -> bytes:
|
||||
"""Serialize a value."""
|
||||
return json.dumps(value).encode()
|
||||
|
||||
|
||||
def _value_deserializer(serialized_value: bytes) -> list[float]:
|
||||
"""Deserialize a value."""
|
||||
return cast("list[float]", json.loads(serialized_value.decode()))
|
||||
|
||||
|
||||
# The warning is global; track emission, so it appears only once.
|
||||
_warned_about_sha1: bool = False
|
||||
|
||||
|
||||
def _warn_about_sha1_encoder() -> None:
|
||||
"""Emit a one-time warning about SHA-1 collision weaknesses."""
|
||||
global _warned_about_sha1 # noqa: PLW0603
|
||||
if not _warned_about_sha1:
|
||||
warnings.warn(
|
||||
"Using default key encoder: SHA-1 is *not* collision-resistant. "
|
||||
"While acceptable for most cache scenarios, a motivated attacker "
|
||||
"can craft two different payloads that map to the same cache key. "
|
||||
"If that risk matters in your environment, supply a stronger "
|
||||
"encoder (e.g. SHA-256 or BLAKE2) via the `key_encoder` argument. "
|
||||
"If you change the key encoder, consider also creating a new cache, "
|
||||
"to avoid (the potential for) collisions with existing keys.",
|
||||
category=UserWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
_warned_about_sha1 = True
|
||||
|
||||
|
||||
class CacheBackedEmbeddings(Embeddings):
|
||||
"""Interface for caching results from embedding models.
|
||||
|
||||
The interface works with any store that implements the abstract store
|
||||
interface, accepting keys of type ``str`` and values of ``list[float]``.
|
||||
|
||||
If need be, the interface can be extended to accept other implementations
|
||||
of the value serializer and deserializer, as well as the key encoder.
|
||||
|
||||
Note that by default only document embeddings are cached. To cache query
|
||||
embeddings too, pass a ``query_embedding_store`` to the constructor.
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.embeddings import CacheBackedEmbeddings
|
||||
from langchain.storage import LocalFileStore
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
|
||||
store = LocalFileStore('./my_cache')
|
||||
|
||||
underlying_embedder = OpenAIEmbeddings()
|
||||
embedder = CacheBackedEmbeddings.from_bytes_store(
|
||||
underlying_embedder, store, namespace=underlying_embedder.model
|
||||
)
|
||||
|
||||
# Embedding is computed and cached
|
||||
embeddings = embedder.embed_documents(["hello", "goodbye"])
|
||||
|
||||
# Embeddings are retrieved from the cache, no computation is done
|
||||
embeddings = embedder.embed_documents(["hello", "goodbye"])
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
underlying_embeddings: Embeddings,
|
||||
document_embedding_store: BaseStore[str, list[float]],
|
||||
*,
|
||||
batch_size: Optional[int] = None,
|
||||
query_embedding_store: Optional[BaseStore[str, list[float]]] = None,
|
||||
) -> None:
|
||||
"""Initialize the embedder.
|
||||
|
||||
Args:
|
||||
underlying_embeddings: the embedder to use for computing embeddings.
|
||||
document_embedding_store: The store to use for caching document embeddings.
|
||||
batch_size: The number of documents to embed between store updates.
|
||||
query_embedding_store: The store to use for caching query embeddings.
|
||||
If ``None``, query embeddings are not cached.
|
||||
"""
|
||||
super().__init__()
|
||||
self.document_embedding_store = document_embedding_store
|
||||
self.query_embedding_store = query_embedding_store
|
||||
self.underlying_embeddings = underlying_embeddings
|
||||
self.batch_size = batch_size
|
||||
|
||||
def embed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
"""Embed a list of texts.
|
||||
|
||||
The method first checks the cache for the embeddings.
|
||||
If the embeddings are not found, the method uses the underlying embedder
|
||||
to embed the documents and stores the results in the cache.
|
||||
|
||||
Args:
|
||||
texts: A list of texts to embed.
|
||||
|
||||
Returns:
|
||||
A list of embeddings for the given texts.
|
||||
"""
|
||||
vectors: list[Union[list[float], None]] = self.document_embedding_store.mget(
|
||||
texts,
|
||||
)
|
||||
all_missing_indices: list[int] = [
|
||||
i for i, vector in enumerate(vectors) if vector is None
|
||||
]
|
||||
|
||||
for missing_indices in batch_iterate(self.batch_size, all_missing_indices):
|
||||
missing_texts = [texts[i] for i in missing_indices]
|
||||
missing_vectors = self.underlying_embeddings.embed_documents(missing_texts)
|
||||
self.document_embedding_store.mset(
|
||||
list(zip(missing_texts, missing_vectors)),
|
||||
)
|
||||
for index, updated_vector in zip(missing_indices, missing_vectors):
|
||||
vectors[index] = updated_vector
|
||||
|
||||
return cast(
|
||||
"list[list[float]]",
|
||||
vectors,
|
||||
) # Nones should have been resolved by now
|
||||
|
||||
async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
"""Embed a list of texts.
|
||||
|
||||
The method first checks the cache for the embeddings.
|
||||
If the embeddings are not found, the method uses the underlying embedder
|
||||
to embed the documents and stores the results in the cache.
|
||||
|
||||
Args:
|
||||
texts: A list of texts to embed.
|
||||
|
||||
Returns:
|
||||
A list of embeddings for the given texts.
|
||||
"""
|
||||
vectors: list[
|
||||
Union[list[float], None]
|
||||
] = await self.document_embedding_store.amget(texts)
|
||||
all_missing_indices: list[int] = [
|
||||
i for i, vector in enumerate(vectors) if vector is None
|
||||
]
|
||||
|
||||
# batch_iterate supports None batch_size which returns all elements at once
|
||||
# as a single batch.
|
||||
for missing_indices in batch_iterate(self.batch_size, all_missing_indices):
|
||||
missing_texts = [texts[i] for i in missing_indices]
|
||||
missing_vectors = await self.underlying_embeddings.aembed_documents(
|
||||
missing_texts,
|
||||
)
|
||||
await self.document_embedding_store.amset(
|
||||
list(zip(missing_texts, missing_vectors)),
|
||||
)
|
||||
for index, updated_vector in zip(missing_indices, missing_vectors):
|
||||
vectors[index] = updated_vector
|
||||
|
||||
return cast(
|
||||
"list[list[float]]",
|
||||
vectors,
|
||||
) # Nones should have been resolved by now
|
||||
|
||||
def embed_query(self, text: str) -> list[float]:
|
||||
"""Embed query text.
|
||||
|
||||
By default, this method does not cache queries. To enable query caching, pass
|
||||
a ``query_embedding_store`` to the constructor, or ``query_embedding_cache=True``
|
||||
to ``from_bytes_store``.
|
||||
|
||||
Args:
|
||||
text: The text to embed.
|
||||
|
||||
Returns:
|
||||
The embedding for the given text.
|
||||
"""
|
||||
if not self.query_embedding_store:
|
||||
return self.underlying_embeddings.embed_query(text)
|
||||
|
||||
(cached,) = self.query_embedding_store.mget([text])
|
||||
if cached is not None:
|
||||
return cached
|
||||
|
||||
vector = self.underlying_embeddings.embed_query(text)
|
||||
self.query_embedding_store.mset([(text, vector)])
|
||||
return vector
|
||||
|
||||
async def aembed_query(self, text: str) -> list[float]:
|
||||
"""Embed query text.
|
||||
|
||||
By default, this method does not cache queries. To enable query caching, pass
|
||||
a ``query_embedding_store`` to the constructor, or ``query_embedding_cache=True``
|
||||
to ``from_bytes_store``.
|
||||
|
||||
Args:
|
||||
text: The text to embed.
|
||||
|
||||
Returns:
|
||||
The embedding for the given text.
|
||||
"""
|
||||
if not self.query_embedding_store:
|
||||
return await self.underlying_embeddings.aembed_query(text)
|
||||
|
||||
(cached,) = await self.query_embedding_store.amget([text])
|
||||
if cached is not None:
|
||||
return cached
|
||||
|
||||
vector = await self.underlying_embeddings.aembed_query(text)
|
||||
await self.query_embedding_store.amset([(text, vector)])
|
||||
return vector
|
||||
|
||||
@classmethod
|
||||
def from_bytes_store(
|
||||
cls,
|
||||
underlying_embeddings: Embeddings,
|
||||
document_embedding_cache: ByteStore,
|
||||
*,
|
||||
namespace: str = "",
|
||||
batch_size: Optional[int] = None,
|
||||
query_embedding_cache: Union[bool, ByteStore] = False,
|
||||
key_encoder: Union[
|
||||
Callable[[str], str],
|
||||
Literal["sha1", "blake2b", "sha256", "sha512"],
|
||||
] = "sha1",
|
||||
) -> CacheBackedEmbeddings:
|
||||
"""On-ramp that adds the necessary serialization and encoding to the store.
|
||||
|
||||
Args:
|
||||
underlying_embeddings: The embedder to use for embedding.
|
||||
document_embedding_cache: The cache to use for storing document embeddings.
|
||||
namespace: The namespace to use for the document cache.
|
||||
This namespace is used to avoid collisions with other caches.
|
||||
For example, set it to the name of the embedding model used.
|
||||
batch_size: The number of documents to embed between store updates.
|
||||
query_embedding_cache: The cache to use for storing query embeddings.
|
||||
True to use the same cache as document embeddings.
|
||||
False to not cache query embeddings.
|
||||
key_encoder: Optional callable to encode keys. If not provided,
|
||||
a default encoder using SHA-1 will be used. SHA-1 is not
|
||||
collision-resistant, and a motivated attacker could craft two
|
||||
different texts that hash to the same cache key.
|
||||
|
||||
New applications should use one of the alternative encoders
|
||||
or provide a custom and strong key encoder function to avoid this risk.
|
||||
|
||||
If you change a key encoder in an existing cache, consider
|
||||
just creating a new cache, to avoid (the potential for)
|
||||
collisions with existing keys or having duplicate keys
|
||||
for the same text in the cache.
|
||||
|
||||
Returns:
|
||||
An instance of CacheBackedEmbeddings that uses the provided cache.
|
||||
"""
|
||||
if isinstance(key_encoder, str):
|
||||
key_encoder = _make_default_key_encoder(namespace, key_encoder)
|
||||
elif callable(key_encoder):
|
||||
# If a custom key encoder is provided, it should not be used with a
|
||||
# namespace.
|
||||
# A user can handle namespacing directly in their custom key encoder.
|
||||
if namespace:
|
||||
msg = (
|
||||
"Do not supply `namespace` when using a custom key_encoder; "
|
||||
"add any prefixing inside the encoder itself."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
else:
|
||||
msg = (
|
||||
"key_encoder must be either 'blake2b', 'sha1', 'sha256', 'sha512' "
|
||||
"or a callable that encodes keys."
|
||||
)
|
||||
raise ValueError(msg) # noqa: TRY004
|
||||
|
||||
document_embedding_store = EncoderBackedStore[str, list[float]](
|
||||
document_embedding_cache,
|
||||
key_encoder,
|
||||
_value_serializer,
|
||||
_value_deserializer,
|
||||
)
|
||||
if query_embedding_cache is True:
|
||||
query_embedding_store = document_embedding_store
|
||||
elif query_embedding_cache is False:
|
||||
query_embedding_store = None
|
||||
else:
|
||||
query_embedding_store = EncoderBackedStore[str, list[float]](
|
||||
query_embedding_cache,
|
||||
key_encoder,
|
||||
_value_serializer,
|
||||
_value_deserializer,
|
||||
)
|
||||
|
||||
return cls(
|
||||
underlying_embeddings,
|
||||
document_embedding_store,
|
||||
batch_size=batch_size,
|
||||
query_embedding_store=query_embedding_store,
|
||||
)
|
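A hedged end-to-end sketch of the factory above. It assumes ``langchain-openai`` is installed; any ``ByteStore`` works in place of the in-memory one:

.. code-block:: python

    from langchain_openai import OpenAIEmbeddings

    from langchain.embeddings import CacheBackedEmbeddings
    from langchain.storage import InMemoryByteStore

    store = InMemoryByteStore()
    embedder = CacheBackedEmbeddings.from_bytes_store(
        OpenAIEmbeddings(model="text-embedding-3-small"),
        store,
        namespace="text-embedding-3-small",
        query_embedding_cache=True,  # share the document cache for queries
        key_encoder="sha256",        # opt out of SHA-1 and its warning
    )
    embedder.embed_documents(["hello", "goodbye"])  # computed and cached
    embedder.embed_query("hello")  # served from the shared cache; no second API call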
15
libs/langchain_v1/langchain/globals.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""Global values and configuration that apply to all of LangChain."""
|
||||
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.caches import BaseCache
|
||||
|
||||
|
||||
# DO NOT USE THESE VALUES DIRECTLY!
|
||||
# Use them only via `get_<X>()` and `set_<X>()` below,
|
||||
# or else your code may behave unexpectedly with other uses of these global settings:
|
||||
# https://github.com/langchain-ai/langchain/pull/11311#issuecomment-1743780004
|
||||
_verbose: bool = False
|
||||
_debug: bool = False
|
||||
_llm_cache: Optional["BaseCache"] = None
|
16
libs/langchain_v1/langchain/memory/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""TBD: This module should provide high level building blocks for memory management.
|
||||
|
||||
We may want to wait until we combine:
|
||||
|
||||
1. langmem
|
||||
2. some basic functions for message summarization
|
||||
"""
|
||||
|
||||
from langchain_core.messages import filter_messages, trim_messages
|
||||
from langchain_core.messages.utils import count_tokens_approximately
|
||||
|
||||
__all__ = [
|
||||
"count_tokens_approximately",
|
||||
"filter_messages",
|
||||
"trim_messages",
|
||||
]
|
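An illustrative use of the re-exported helpers; the message contents and token budget below are made up:

.. code-block:: python

    from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

    from langchain.memory import count_tokens_approximately, trim_messages

    history = [
        SystemMessage("You are terse."),
        HumanMessage("Summarize our conversation so far."),
        AIMessage("We discussed caching embeddings in a key-value store."),
    ]
    # Keep roughly the most recent 500 tokens, always retaining the system message.
    trimmed = trim_messages(
        history,
        max_tokens=500,
        token_counter=count_tokens_approximately,
        strategy="last",
        include_system=True,
    )
    approx_tokens = count_tokens_approximately(trimmed)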
39
libs/langchain_v1/langchain/prompts/__init__.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from langchain_core.example_selectors import (
|
||||
LengthBasedExampleSelector,
|
||||
MaxMarginalRelevanceExampleSelector,
|
||||
SemanticSimilarityExampleSelector,
|
||||
)
|
||||
from langchain_core.prompts import (
|
||||
AIMessagePromptTemplate,
|
||||
BaseChatPromptTemplate,
|
||||
BasePromptTemplate,
|
||||
ChatMessagePromptTemplate,
|
||||
ChatPromptTemplate,
|
||||
FewShotChatMessagePromptTemplate,
|
||||
FewShotPromptTemplate,
|
||||
FewShotPromptWithTemplates,
|
||||
HumanMessagePromptTemplate,
|
||||
MessagesPlaceholder,
|
||||
PromptTemplate,
|
||||
StringPromptTemplate,
|
||||
SystemMessagePromptTemplate,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AIMessagePromptTemplate",
|
||||
"BaseChatPromptTemplate",
|
||||
"BasePromptTemplate",
|
||||
"ChatMessagePromptTemplate",
|
||||
"ChatPromptTemplate",
|
||||
"FewShotChatMessagePromptTemplate",
|
||||
"FewShotPromptTemplate",
|
||||
"FewShotPromptWithTemplates",
|
||||
"HumanMessagePromptTemplate",
|
||||
"LengthBasedExampleSelector",
|
||||
"MaxMarginalRelevanceExampleSelector",
|
||||
"MessagesPlaceholder",
|
||||
"PromptTemplate",
|
||||
"SemanticSimilarityExampleSelector",
|
||||
"StringPromptTemplate",
|
||||
"SystemMessagePromptTemplate",
|
||||
]
|
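A short sketch of the re-exported prompt primitives; the template text is illustrative:

.. code-block:: python

    from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder

    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", "You are a helpful assistant."),
            MessagesPlaceholder("history"),
            ("human", "{question}"),
        ]
    )
    messages = prompt.invoke(
        {"history": [], "question": "How is caching configured?"}
    ).to_messages()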
0
libs/langchain_v1/langchain/py.typed
Normal file
27
libs/langchain_v1/langchain/storage/__init__.py
Normal file
@@ -0,0 +1,27 @@
|
||||
"""Implementations of key-value stores and storage helpers.
|
||||
|
||||
This module provides implementations of various key-value stores that conform
|
||||
to a simple key-value interface.
|
||||
|
||||
The primary goal of these stores is to support caching.
|
||||
"""
|
||||
|
||||
from langchain_core.stores import (
|
||||
InMemoryByteStore,
|
||||
InMemoryStore,
|
||||
InvalidKeyException,
|
||||
)
|
||||
|
||||
from langchain.storage._lc_store import create_kv_docstore, create_lc_store
|
||||
from langchain.storage.encoder_backed import EncoderBackedStore
|
||||
from langchain.storage.file_system import LocalFileStore
|
||||
|
||||
__all__ = [
|
||||
"EncoderBackedStore",
|
||||
"InMemoryByteStore",
|
||||
"InMemoryStore",
|
||||
"InvalidKeyException",
|
||||
"LocalFileStore",
|
||||
"create_kv_docstore",
|
||||
"create_lc_store",
|
||||
]
|
91
libs/langchain_v1/langchain/storage/_lc_store.py
Normal file
@@ -0,0 +1,91 @@
|
||||
"""Create a key-value store for any langchain serializable object."""
|
||||
|
||||
from typing import Callable, Optional
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.load import Serializable, dumps, loads
|
||||
from langchain_core.stores import BaseStore, ByteStore
|
||||
|
||||
from langchain.storage.encoder_backed import EncoderBackedStore
|
||||
|
||||
|
||||
def _dump_as_bytes(obj: Serializable) -> bytes:
|
||||
"""Return a bytes representation of a document."""
|
||||
return dumps(obj).encode("utf-8")
|
||||
|
||||
|
||||
def _dump_document_as_bytes(obj: Document) -> bytes:
|
||||
"""Return a bytes representation of a document."""
|
||||
if not isinstance(obj, Document):
|
||||
msg = "Expected a Document instance"
|
||||
raise TypeError(msg)
|
||||
return dumps(obj).encode("utf-8")
|
||||
|
||||
|
||||
def _load_document_from_bytes(serialized: bytes) -> Document:
|
||||
"""Return a document from a bytes representation."""
|
||||
obj = loads(serialized.decode("utf-8"))
|
||||
if not isinstance(obj, Document):
|
||||
msg = f"Expected a Document instance. Got {type(obj)}"
|
||||
raise TypeError(msg)
|
||||
return obj
|
||||
|
||||
|
||||
def _load_from_bytes(serialized: bytes) -> Serializable:
|
||||
"""Return a document from a bytes representation."""
|
||||
return loads(serialized.decode("utf-8"))
|
||||
|
||||
|
||||
def _identity(x: str) -> str:
|
||||
"""Return the same object."""
|
||||
return x
|
||||
|
||||
|
||||
# PUBLIC API
|
||||
|
||||
|
||||
def create_lc_store(
|
||||
store: ByteStore,
|
||||
*,
|
||||
key_encoder: Optional[Callable[[str], str]] = None,
|
||||
) -> BaseStore[str, Serializable]:
|
||||
"""Create a store for langchain serializable objects from a bytes store.
|
||||
|
||||
Args:
|
||||
store: A bytes store to use as the underlying store.
|
||||
key_encoder: A function to encode keys; if None uses identity function.
|
||||
|
||||
Returns:
|
||||
A key-value store for documents.
|
||||
"""
|
||||
return EncoderBackedStore(
|
||||
store,
|
||||
key_encoder or _identity,
|
||||
_dump_as_bytes,
|
||||
_load_from_bytes,
|
||||
)
|
||||
|
||||
|
||||
def create_kv_docstore(
|
||||
store: ByteStore,
|
||||
*,
|
||||
key_encoder: Optional[Callable[[str], str]] = None,
|
||||
) -> BaseStore[str, Document]:
|
||||
"""Create a store for langchain Document objects from a bytes store.
|
||||
|
||||
This store does runtime type checking to ensure that the values are
|
||||
Document objects.
|
||||
|
||||
Args:
|
||||
store: A bytes store to use as the underlying store.
|
||||
key_encoder: A function to encode keys; if None uses identity function.
|
||||
|
||||
Returns:
|
||||
A key-value store for documents.
|
||||
"""
|
||||
return EncoderBackedStore(
|
||||
store,
|
||||
key_encoder or _identity,
|
||||
_dump_document_as_bytes,
|
||||
_load_document_from_bytes,
|
||||
)
|
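A hedged sketch of the public helpers above, backed by the in-memory byte store from ``langchain_core``; the key and document contents are illustrative:

.. code-block:: python

    from langchain_core.documents import Document
    from langchain_core.stores import InMemoryByteStore

    from langchain.storage import create_kv_docstore

    docstore = create_kv_docstore(InMemoryByteStore())
    docstore.mset([("doc-1", Document(page_content="hello"))])
    (restored,) = docstore.mget(["doc-1"])
    assert isinstance(restored, Document) and restored.page_content == "hello"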
127
libs/langchain_v1/langchain/storage/encoder_backed.py
Normal file
@@ -0,0 +1,127 @@
|
||||
from collections.abc import AsyncIterator, Iterator, Sequence
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Optional,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
from langchain_core.stores import BaseStore
|
||||
|
||||
K = TypeVar("K")
|
||||
V = TypeVar("V")
|
||||
|
||||
|
||||
class EncoderBackedStore(BaseStore[K, V]):
|
||||
"""Wraps a store with key and value encoders/decoders.
|
||||
|
||||
Example that uses JSON for encoding/decoding:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import json
|
||||
|
||||
def key_encoder(key: int) -> str:
|
||||
return json.dumps(key)
|
||||
|
||||
def value_serializer(value: float) -> str:
|
||||
return json.dumps(value)
|
||||
|
||||
def value_deserializer(serialized_value: str) -> float:
|
||||
return json.loads(serialized_value)
|
||||
|
||||
# Create an instance of the abstract store
|
||||
abstract_store = MyCustomStore()
|
||||
|
||||
# Create an instance of the encoder-backed store
|
||||
store = EncoderBackedStore(
|
||||
store=abstract_store,
|
||||
key_encoder=key_encoder,
|
||||
value_serializer=value_serializer,
|
||||
value_deserializer=value_deserializer
|
||||
)
|
||||
|
||||
# Use the encoder-backed store methods
|
||||
store.mset([(1, 3.14), (2, 2.718)])
|
||||
values = store.mget([1, 2]) # Retrieves [3.14, 2.718]
|
||||
store.mdelete([1, 2]) # Deletes the keys 1 and 2
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
store: BaseStore[str, Any],
|
||||
key_encoder: Callable[[K], str],
|
||||
value_serializer: Callable[[V], bytes],
|
||||
value_deserializer: Callable[[Any], V],
|
||||
) -> None:
|
||||
"""Initialize an EncodedStore."""
|
||||
self.store = store
|
||||
self.key_encoder = key_encoder
|
||||
self.value_serializer = value_serializer
|
||||
self.value_deserializer = value_deserializer
|
||||
|
||||
def mget(self, keys: Sequence[K]) -> list[Optional[V]]:
|
||||
"""Get the values associated with the given keys."""
|
||||
encoded_keys: list[str] = [self.key_encoder(key) for key in keys]
|
||||
values = self.store.mget(encoded_keys)
|
||||
return [
|
||||
self.value_deserializer(value) if value is not None else value
|
||||
for value in values
|
||||
]
|
||||
|
||||
async def amget(self, keys: Sequence[K]) -> list[Optional[V]]:
|
||||
"""Get the values associated with the given keys."""
|
||||
encoded_keys: list[str] = [self.key_encoder(key) for key in keys]
|
||||
values = await self.store.amget(encoded_keys)
|
||||
return [
|
||||
self.value_deserializer(value) if value is not None else value
|
||||
for value in values
|
||||
]
|
||||
|
||||
def mset(self, key_value_pairs: Sequence[tuple[K, V]]) -> None:
|
||||
"""Set the values for the given keys."""
|
||||
encoded_pairs = [
|
||||
(self.key_encoder(key), self.value_serializer(value))
|
||||
for key, value in key_value_pairs
|
||||
]
|
||||
self.store.mset(encoded_pairs)
|
||||
|
||||
async def amset(self, key_value_pairs: Sequence[tuple[K, V]]) -> None:
|
||||
"""Set the values for the given keys."""
|
||||
encoded_pairs = [
|
||||
(self.key_encoder(key), self.value_serializer(value))
|
||||
for key, value in key_value_pairs
|
||||
]
|
||||
await self.store.amset(encoded_pairs)
|
||||
|
||||
def mdelete(self, keys: Sequence[K]) -> None:
|
||||
"""Delete the given keys and their associated values."""
|
||||
encoded_keys = [self.key_encoder(key) for key in keys]
|
||||
self.store.mdelete(encoded_keys)
|
||||
|
||||
async def amdelete(self, keys: Sequence[K]) -> None:
|
||||
"""Delete the given keys and their associated values."""
|
||||
encoded_keys = [self.key_encoder(key) for key in keys]
|
||||
await self.store.amdelete(encoded_keys)
|
||||
|
||||
def yield_keys(
|
||||
self,
|
||||
*,
|
||||
prefix: Optional[str] = None,
|
||||
) -> Union[Iterator[K], Iterator[str]]:
|
||||
"""Get an iterator over keys that match the given prefix."""
|
||||
# For the time being this yields str rather than K, since keys are not
|
||||
# decoded back from their encoded form. Acceptable for debugging; should be fixed.
|
||||
yield from self.store.yield_keys(prefix=prefix)
|
||||
|
||||
async def ayield_keys(
|
||||
self,
|
||||
*,
|
||||
prefix: Optional[str] = None,
|
||||
) -> Union[AsyncIterator[K], AsyncIterator[str]]:
|
||||
"""Get an iterator over keys that match the given prefix."""
|
||||
# For the time being this yields str rather than K, since keys are not
|
||||
# decoded back from their encoded form. Acceptable for debugging; should be fixed.
|
||||
async for key in self.store.ayield_keys(prefix=prefix):
|
||||
yield key
|
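A runnable variant of the docstring example above. It substitutes the in-memory byte store for the placeholder ``MyCustomStore`` and JSON-encodes values to bytes so they match the declared serializer types:

.. code-block:: python

    import json

    from langchain_core.stores import InMemoryByteStore

    store = EncoderBackedStore[int, float](
        InMemoryByteStore(),
        key_encoder=lambda key: str(key),
        value_serializer=lambda value: json.dumps(value).encode("utf-8"),
        value_deserializer=lambda data: json.loads(data.decode("utf-8")),
    )
    store.mset([(1, 3.14), (2, 2.718)])
    assert store.mget([1, 2]) == [3.14, 2.718]
    store.mdelete([1, 2])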
3
libs/langchain_v1/langchain/storage/exceptions.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from langchain_core.stores import InvalidKeyException
|
||||
|
||||
__all__ = ["InvalidKeyException"]
|
176
libs/langchain_v1/langchain/storage/file_system.py
Normal file
@@ -0,0 +1,176 @@
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from collections.abc import Iterator, Sequence
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
from langchain_core.stores import ByteStore
|
||||
|
||||
from langchain.storage.exceptions import InvalidKeyException
|
||||
|
||||
|
||||
class LocalFileStore(ByteStore):
|
||||
"""BaseStore interface that works on the local file system.
|
||||
|
||||
Examples:
|
||||
Create a LocalFileStore instance and perform operations on it:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.storage import LocalFileStore
|
||||
|
||||
# Instantiate the LocalFileStore with the root path
|
||||
file_store = LocalFileStore("/path/to/root")
|
||||
|
||||
# Set values for keys
|
||||
file_store.mset([("key1", b"value1"), ("key2", b"value2")])
|
||||
|
||||
# Get values for keys
|
||||
values = file_store.mget(["key1", "key2"]) # Returns [b"value1", b"value2"]
|
||||
|
||||
# Delete keys
|
||||
file_store.mdelete(["key1"])
|
||||
|
||||
# Iterate over keys
|
||||
for key in file_store.yield_keys():
|
||||
print(key) # noqa: T201
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
root_path: Union[str, Path],
|
||||
*,
|
||||
chmod_file: Optional[int] = None,
|
||||
chmod_dir: Optional[int] = None,
|
||||
update_atime: bool = False,
|
||||
) -> None:
|
||||
"""Implement the BaseStore interface for the local file system.
|
||||
|
||||
Args:
|
||||
root_path (Union[str, Path]): The root path of the file store. All keys are
|
||||
interpreted as paths relative to this root.
|
||||
chmod_file: (optional, defaults to `None`) If specified, sets permissions
|
||||
for newly created files, overriding the current `umask` if needed.
|
||||
chmod_dir: (optional, defaults to `None`) If specified, sets permissions
|
||||
for newly created dirs, overriding the current `umask` if needed.
|
||||
update_atime: (optional, defaults to `False`) If `True`, updates the
|
||||
filesystem access time (but not the modified time) when a file is read.
|
||||
This allows MRU/LRU cache policies to be implemented for filesystems
|
||||
where access time updates are disabled.
|
||||
"""
|
||||
self.root_path = Path(root_path).absolute()
|
||||
self.chmod_file = chmod_file
|
||||
self.chmod_dir = chmod_dir
|
||||
self.update_atime = update_atime
|
||||
|
||||
def _get_full_path(self, key: str) -> Path:
|
||||
"""Get the full path for a given key relative to the root path.
|
||||
|
||||
Args:
|
||||
key (str): The key relative to the root path.
|
||||
|
||||
Returns:
|
||||
Path: The full path for the given key.
|
||||
"""
|
||||
if not re.match(r"^[a-zA-Z0-9_.\-/]+$", key):
|
||||
msg = f"Invalid characters in key: {key}"
|
||||
raise InvalidKeyException(msg)
|
||||
full_path = (self.root_path / key).resolve()
|
||||
root_path = self.root_path.resolve()
|
||||
common_path = os.path.commonpath([root_path, full_path])
|
||||
if common_path != str(root_path):
|
||||
msg = (
|
||||
f"Invalid key: {key}. Key should be relative to the full path. "
|
||||
f"{root_path} vs. {common_path} and full path of {full_path}"
|
||||
)
|
||||
raise InvalidKeyException(msg)
|
||||
|
||||
return full_path
|
||||
|
||||
def _mkdir_for_store(self, dir_path: Path) -> None:
|
||||
"""Makes a store directory path (including parents) with specified permissions.
|
||||
|
||||
This is needed because `Path.mkdir()` is restricted by the current `umask`,
|
||||
whereas the explicit `os.chmod()` used here is not.
|
||||
|
||||
Args:
|
||||
dir_path: (Path) The store directory to make
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
if not dir_path.exists():
|
||||
self._mkdir_for_store(dir_path.parent)
|
||||
dir_path.mkdir(exist_ok=True)
|
||||
if self.chmod_dir is not None:
|
||||
dir_path.chmod(self.chmod_dir)
|
||||
|
||||
def mget(self, keys: Sequence[str]) -> list[Optional[bytes]]:
|
||||
"""Get the values associated with the given keys.
|
||||
|
||||
Args:
|
||||
keys: A sequence of keys.
|
||||
|
||||
Returns:
|
||||
A sequence of optional values associated with the keys.
|
||||
If a key is not found, the corresponding value will be None.
|
||||
"""
|
||||
values: list[Optional[bytes]] = []
|
||||
for key in keys:
|
||||
full_path = self._get_full_path(key)
|
||||
if full_path.exists():
|
||||
value = full_path.read_bytes()
|
||||
values.append(value)
|
||||
if self.update_atime:
|
||||
# update access time only; preserve modified time
|
||||
os.utime(full_path, (time.time(), full_path.stat().st_mtime))
|
||||
else:
|
||||
values.append(None)
|
||||
return values
|
||||
|
||||
def mset(self, key_value_pairs: Sequence[tuple[str, bytes]]) -> None:
|
||||
"""Set the values for the given keys.
|
||||
|
||||
Args:
|
||||
key_value_pairs: A sequence of key-value pairs.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
for key, value in key_value_pairs:
|
||||
full_path = self._get_full_path(key)
|
||||
self._mkdir_for_store(full_path.parent)
|
||||
full_path.write_bytes(value)
|
||||
if self.chmod_file is not None:
|
||||
full_path.chmod(self.chmod_file)
|
||||
|
||||
def mdelete(self, keys: Sequence[str]) -> None:
|
||||
"""Delete the given keys and their associated values.
|
||||
|
||||
Args:
|
||||
keys (Sequence[str]): A sequence of keys to delete.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
for key in keys:
|
||||
full_path = self._get_full_path(key)
|
||||
if full_path.exists():
|
||||
full_path.unlink()
|
||||
|
||||
def yield_keys(self, prefix: Optional[str] = None) -> Iterator[str]:
|
||||
"""Get an iterator over keys that match the given prefix.
|
||||
|
||||
Args:
|
||||
prefix (Optional[str]): The prefix to match.
|
||||
|
||||
Returns:
|
||||
Iterator[str]: An iterator over keys that match the given prefix.
|
||||
"""
|
||||
prefix_path = self._get_full_path(prefix) if prefix else self.root_path
|
||||
for file in prefix_path.rglob("*"):
|
||||
if file.is_file():
|
||||
relative_path = file.relative_to(self.root_path)
|
||||
yield str(relative_path)
|
13
libs/langchain_v1/langchain/storage/in_memory.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""In memory store that is not thread safe and has no eviction policy.
|
||||
|
||||
This is a simple implementation of the BaseStore using a dictionary that is useful
|
||||
primarily for unit testing purposes.
|
||||
"""
|
||||
|
||||
from langchain_core.stores import InMemoryBaseStore, InMemoryByteStore, InMemoryStore
|
||||
|
||||
__all__ = [
|
||||
"InMemoryBaseStore",
|
||||
"InMemoryByteStore",
|
||||
"InMemoryStore",
|
||||
]
|
50
libs/langchain_v1/langchain/text_splitter.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""Kept for backwards compatibility."""
|
||||
|
||||
from langchain_text_splitters import (
|
||||
Language,
|
||||
RecursiveCharacterTextSplitter,
|
||||
TextSplitter,
|
||||
Tokenizer,
|
||||
TokenTextSplitter,
|
||||
)
|
||||
from langchain_text_splitters.base import split_text_on_tokens
|
||||
from langchain_text_splitters.character import CharacterTextSplitter
|
||||
from langchain_text_splitters.html import ElementType, HTMLHeaderTextSplitter
|
||||
from langchain_text_splitters.json import RecursiveJsonSplitter
|
||||
from langchain_text_splitters.konlpy import KonlpyTextSplitter
|
||||
from langchain_text_splitters.latex import LatexTextSplitter
|
||||
from langchain_text_splitters.markdown import (
|
||||
HeaderType,
|
||||
LineType,
|
||||
MarkdownHeaderTextSplitter,
|
||||
MarkdownTextSplitter,
|
||||
)
|
||||
from langchain_text_splitters.nltk import NLTKTextSplitter
|
||||
from langchain_text_splitters.python import PythonCodeTextSplitter
|
||||
from langchain_text_splitters.sentence_transformers import (
|
||||
SentenceTransformersTokenTextSplitter,
|
||||
)
|
||||
from langchain_text_splitters.spacy import SpacyTextSplitter
|
||||
|
||||
__all__ = [
|
||||
"CharacterTextSplitter",
|
||||
"ElementType",
|
||||
"HTMLHeaderTextSplitter",
|
||||
"HeaderType",
|
||||
"KonlpyTextSplitter",
|
||||
"Language",
|
||||
"LatexTextSplitter",
|
||||
"LineType",
|
||||
"MarkdownHeaderTextSplitter",
|
||||
"MarkdownTextSplitter",
|
||||
"NLTKTextSplitter",
|
||||
"PythonCodeTextSplitter",
|
||||
"RecursiveCharacterTextSplitter",
|
||||
"RecursiveJsonSplitter",
|
||||
"SentenceTransformersTokenTextSplitter",
|
||||
"SpacyTextSplitter",
|
||||
"TextSplitter",
|
||||
"TokenTextSplitter",
|
||||
"Tokenizer",
|
||||
"split_text_on_tokens",
|
||||
]
|
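An illustrative use of one of the re-exported splitters:

.. code-block:: python

    from langchain.text_splitter import RecursiveCharacterTextSplitter

    splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=20)
    chunks = splitter.split_text("LangChain text splitters break long documents. " * 20)
    # Each chunk is at most ~100 characters, with ~20 characters of overlap.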
17
libs/langchain_v1/langchain/tools/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from langchain_core.tools import (
|
||||
BaseTool,
|
||||
InjectedToolArg,
|
||||
InjectedToolCallId,
|
||||
Tool,
|
||||
ToolException,
|
||||
tool,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BaseTool",
|
||||
"InjectedToolArg",
|
||||
"InjectedToolCallId",
|
||||
"Tool",
|
||||
"ToolException",
|
||||
"tool",
|
||||
]
|
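A minimal sketch of the re-exported tool primitives; the ``multiply`` function is just an example:

.. code-block:: python

    from langchain.tools import tool

    @tool
    def multiply(x: int, y: int) -> int:
        """Multiply two integers."""
        return x * y

    result = multiply.invoke({"x": 6, "y": 7})  # 42; the schema is derived from the signature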
190
libs/langchain_v1/pyproject.toml
Normal file
@@ -0,0 +1,190 @@
|
||||
[build-system]
|
||||
requires = ["pdm-backend"]
|
||||
build-backend = "pdm.backend"
|
||||
|
||||
[project]
|
||||
authors = []
|
||||
license = { text = "MIT" }
|
||||
requires-python = ">=3.9, <4.0"
|
||||
dependencies = [
|
||||
"langchain-core<1.0.0,>=0.3.66",
|
||||
"langchain-text-splitters<1.0.0,>=0.3.8",
|
||||
"langgraph>=0.5.4",
|
||||
"pydantic>=2.7.4",
|
||||
]
|
||||
|
||||
|
||||
name = "langchain"
|
||||
version = "1.0.0dev1"
|
||||
description = "Building applications with LLMs through composability"
|
||||
readme = "README.md"
|
||||
|
||||
[project.optional-dependencies]
|
||||
# community = ["langchain-community"]
|
||||
anthropic = ["langchain-anthropic"]
|
||||
openai = ["langchain-openai"]
|
||||
azure-ai = ["langchain-azure-ai"]
|
||||
# cohere = ["langchain-cohere"]
|
||||
google-vertexai = ["langchain-google-vertexai"]
|
||||
google-genai = ["langchain-google-genai"]
|
||||
fireworks = ["langchain-fireworks"]
|
||||
ollama = ["langchain-ollama"]
|
||||
together = ["langchain-together"]
|
||||
mistralai = ["langchain-mistralai"]
|
||||
huggingface = ["langchain-huggingface"]
|
||||
groq = ["langchain-groq"]
|
||||
aws = ["langchain-aws"]
|
||||
deepseek = ["langchain-deepseek"]
|
||||
xai = ["langchain-xai"]
|
||||
perplexity = ["langchain-perplexity"]
|
||||
|
||||
[project.urls]
|
||||
"Source Code" = "https://github.com/langchain-ai/langchain/tree/master/libs/langchain"
|
||||
"Release Notes" = "https://github.com/langchain-ai/langchain/releases?q=tag%3A%22langchain%3D%3D0%22&expanded=true"
|
||||
repository = "https://github.com/langchain-ai/langchain"
|
||||
|
||||
[dependency-groups]
|
||||
test = [
|
||||
"pytest<9,>=8",
|
||||
"pytest-cov<5.0.0,>=4.0.0",
|
||||
"pytest-watcher<1.0.0,>=0.2.6",
|
||||
"pytest-asyncio<1.0.0,>=0.23.2",
|
||||
"pytest-socket<1.0.0,>=0.6.0",
|
||||
"syrupy<5.0.0,>=4.0.2",
|
||||
"pytest-xdist<4.0.0,>=3.6.1",
|
||||
"blockbuster<1.6,>=1.5.18",
|
||||
"langchain-tests",
|
||||
"langchain-core",
|
||||
"langchain-text-splitters",
|
||||
"langchain-openai",
|
||||
"toml>=0.10.2",
|
||||
]
|
||||
codespell = ["codespell<3.0.0,>=2.2.0"]
|
||||
lint = [
|
||||
"ruff<0.13,>=0.12.2",
|
||||
"mypy<1.16,>=1.15",
|
||||
]
|
||||
typing = [
|
||||
"types-toml>=0.10.8.20240310",
|
||||
]
|
||||
|
||||
test_integration = [
|
||||
"vcrpy>=7.0",
|
||||
"urllib3<2; python_version < \"3.10\"",
|
||||
"wrapt<2.0.0,>=1.15.0",
|
||||
"python-dotenv<2.0.0,>=1.0.0",
|
||||
"cassio<1.0.0,>=0.1.0",
|
||||
"langchainhub<1.0.0,>=0.1.16",
|
||||
"langchain-core",
|
||||
"langchain-text-splitters",
|
||||
]
|
||||
|
||||
[tool.uv.sources]
|
||||
langchain-core = { path = "../core", editable = true }
|
||||
langchain-tests = { path = "../standard-tests", editable = true }
|
||||
langchain-text-splitters = { path = "../text-splitters", editable = true }
|
||||
langchain-openai = { path = "../partners/openai", editable = true }
|
||||
|
||||
[tool.ruff]
|
||||
target-version = "py39"
|
||||
exclude = ["tests/integration_tests/examples/non-utf8-encoding.py"]
|
||||
|
||||
[tool.mypy]
|
||||
strict = "True"
|
||||
strict_bytes = "True"
|
||||
ignore_missing_imports = "True"
|
||||
enable_error_code = "deprecated"
|
||||
report_deprecated_as_note = "True"
|
||||
|
||||
# TODO: activate for 'strict' checking
|
||||
disallow_untyped_calls = "False"
|
||||
disallow_any_generics = "False"
|
||||
disallow_untyped_decorators = "False"
|
||||
warn_return_any = "False"
|
||||
strict_equality = "False"
|
||||
|
||||
[tool.codespell]
|
||||
skip = ".git,*.pdf,*.svg,*.pdf,*.yaml,*.ipynb,poetry.lock,*.min.js,*.css,package-lock.json,example_data,_dist,examples,*.trig"
|
||||
ignore-regex = ".*(Stati Uniti|Tense=Pres).*"
|
||||
ignore-words-list = "momento,collison,ned,foor,reworkd,parth,whats,aapply,mysogyny,unsecure,damon,crate,aadd,symbl,precesses,accademia,nin"
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = [
|
||||
"ALL"
|
||||
]
|
||||
ignore = [
|
||||
"D100", # pydocstyle: Missing docstring in public module
|
||||
"D104", # pydocstyle: Missing docstring in public package
|
||||
"D105", # pydocstyle: Missing docstring in magic method
|
||||
"COM812", # Messes with the formatter
|
||||
"ISC001", # Messes with the formatter
|
||||
"PERF203", # Rarely useful
|
||||
"SLF001", # Private member access
|
||||
"UP007", # pyupgrade: non-pep604-annotation-union
|
||||
"PLC0415", # Imports should be at the top. Not always desirable
|
||||
"PLR0913", # Too many arguments in function definition
|
||||
]
|
||||
unfixable = ["B028"] # People should intentionally tune the stacklevel
|
||||
|
||||
pydocstyle.convention = "google"
|
||||
pyupgrade.keep-runtime-typing = true
|
||||
flake8-annotations.allow-star-arg-any = true
|
||||
|
||||
[tool.ruff.lint.per-file-ignores]
|
||||
"tests/*" = [
|
||||
"D", # Documentation rules
|
||||
"PLC0415", # Imports should be at the top. Not always desirable for tests
|
||||
]
|
||||
|
||||
[tool.ruff.lint.extend-per-file-ignores]
|
||||
"scripts/check_imports.py" = ["ALL"]
|
||||
|
||||
"langchain/globals.py" = [
|
||||
"PLW"
|
||||
]
|
||||
|
||||
"langchain/chat_models/base.py" = [
|
||||
"ANN",
|
||||
"C901",
|
||||
"FIX002",
|
||||
"N802",
|
||||
"PLR0911",
|
||||
"PLR0912",
|
||||
"PLR0915",
|
||||
]
|
||||
|
||||
"langchain/embeddings/base.py" = [
|
||||
"PLR0911",
|
||||
"PLR0913",
|
||||
]
|
||||
|
||||
"tests/**/*.py" = [
|
||||
"S101", # Tests need assertions
|
||||
"S311", # Standard pseudo-random generators are not suitable for cryptographic purposes
|
||||
"SLF001", # Private member access in tests
|
||||
"PLR2004", # Magic values are perfectly fine in unit tests (e.g. 0, 1, 2, etc.)
|
||||
"C901", # Too complex
|
||||
"ANN401", # Annotated type is not necessary
|
||||
"N802", # Function name should be lowercase
|
||||
"PLW1641", # Object does not implement __hash__ method
|
||||
"ARG002", # Unused argument
|
||||
"BLE001", # Do not catch blind exception
|
||||
"N801", # class name should use CapWords convention
|
||||
]
|
||||
|
||||
[tool.coverage.run]
|
||||
omit = ["tests/*"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
addopts = "--strict-markers --strict-config --durations=5 --snapshot-warn-unused -vv"
|
||||
markers = [
|
||||
"requires: mark tests as requiring a specific library",
|
||||
"scheduled: mark tests to run in scheduled testing",
|
||||
"compile: mark placeholder test used to compile integration tests without running them",
|
||||
]
|
||||
asyncio_mode = "auto"
|
||||
filterwarnings = [
|
||||
"ignore::langchain_core._api.beta_decorator.LangChainBetaWarning",
|
||||
"ignore::langchain_core._api.deprecation.LangChainDeprecationWarning:tests",
|
||||
"ignore::langchain_core._api.deprecation.LangChainPendingDeprecationWarning:tests",
|
||||
]
|
31
libs/langchain_v1/scripts/check_imports.py
Normal file
@@ -0,0 +1,31 @@
|
||||
"""Quickly verify that a list of Python files can be loaded by the Python interpreter
|
||||
without raising any errors. Run before running more expensive tests. Useful in
|
||||
Makefiles.
|
||||
|
||||
If loading a file fails, the script prints the problematic filename and the detailed
|
||||
error traceback.
|
||||
"""
|
||||
|
||||
import random
|
||||
import string
|
||||
import sys
|
||||
import traceback
|
||||
from importlib.machinery import SourceFileLoader
|
||||
|
||||
if __name__ == "__main__":
|
||||
files = sys.argv[1:]
|
||||
has_failure = False
|
||||
for file in files:
|
||||
try:
|
||||
module_name = "".join(
|
||||
random.choice(string.ascii_letters) # noqa: S311
|
||||
for _ in range(20)
|
||||
)
|
||||
SourceFileLoader(module_name, file).load_module()
|
||||
except Exception:
|
||||
has_failure = True
|
||||
print(file) # noqa: T201
|
||||
traceback.print_exc()
|
||||
print() # noqa: T201
|
||||
|
||||
sys.exit(1 if has_failure else 0)
|
1
libs/langchain_v1/tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""All tests for this package."""
|
1
libs/langchain_v1/tests/integration_tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""All integration tests (tests that call out to an external API)."""
|
1
libs/langchain_v1/tests/integration_tests/cache/__init__.py
vendored
Normal file
@@ -0,0 +1 @@
|
||||
"""All integration tests for Cache objects."""
|
81
libs/langchain_v1/tests/integration_tests/cache/fake_embeddings.py
vendored
Normal file
@@ -0,0 +1,81 @@
|
||||
"""Fake Embedding class for testing purposes."""
|
||||
|
||||
import math
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
|
||||
fake_texts = ["foo", "bar", "baz"]
|
||||
|
||||
|
||||
class FakeEmbeddings(Embeddings):
|
||||
"""Fake embeddings functionality for testing."""
|
||||
|
||||
def embed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
"""Return simple embeddings.
|
||||
Embeddings encode each text as its index."""
|
||||
return [[1.0] * 9 + [float(i)] for i in range(len(texts))]
|
||||
|
||||
async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
return self.embed_documents(texts)
|
||||
|
||||
def embed_query(self, text: str) -> list[float]:
|
||||
"""Return constant query embeddings.
|
||||
Embeddings are identical to embed_documents(texts)[0].
|
||||
Distance to each text will be that text's index,
|
||||
as it was passed to embed_documents."""
|
||||
return [1.0] * 9 + [0.0]
|
||||
|
||||
async def aembed_query(self, text: str) -> list[float]:
|
||||
return self.embed_query(text)
|
||||
|
||||
|
||||
class ConsistentFakeEmbeddings(FakeEmbeddings):
|
||||
"""Fake embeddings which remember all the texts seen so far to return consistent
|
||||
vectors for the same texts."""
|
||||
|
||||
def __init__(self, dimensionality: int = 10) -> None:
|
||||
self.known_texts: list[str] = []
|
||||
self.dimensionality = dimensionality
|
||||
|
||||
def embed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
"""Return consistent embeddings for each text seen so far."""
|
||||
out_vectors = []
|
||||
for text in texts:
|
||||
if text not in self.known_texts:
|
||||
self.known_texts.append(text)
|
||||
vector = [1.0] * (self.dimensionality - 1) + [
|
||||
float(self.known_texts.index(text)),
|
||||
]
|
||||
out_vectors.append(vector)
|
||||
return out_vectors
|
||||
|
||||
def embed_query(self, text: str) -> list[float]:
|
||||
"""Return consistent embeddings for the text, if seen before, or a constant
|
||||
one if the text is unknown."""
|
||||
return self.embed_documents([text])[0]
|
||||
|
||||
|
||||
class AngularTwoDimensionalEmbeddings(Embeddings):
|
||||
"""
|
||||
From angles (as strings in units of pi) to unit embedding vectors on a circle.
|
||||
"""
|
||||
|
||||
def embed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
"""
|
||||
Make a list of texts into a list of embedding vectors.
|
||||
"""
|
||||
return [self.embed_query(text) for text in texts]
|
||||
|
||||
def embed_query(self, text: str) -> list[float]:
|
||||
"""
|
||||
Convert input text to a 'vector' (list of floats).
|
||||
If the text is a number, use it as the angle for the
|
||||
unit vector in units of pi.
|
||||
Any other input text becomes the singular result [0, 0] !
|
||||
"""
|
||||
try:
|
||||
angle = float(text)
|
||||
return [math.cos(angle * math.pi), math.sin(angle * math.pi)]
|
||||
except ValueError:
|
||||
# Assume: just test string, no attention is paid to values.
|
||||
return [0.0, 0.0]
|
@@ -0,0 +1,59 @@
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langchain_tests.integration_tests import ChatModelIntegrationTests
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.chat_models import init_chat_model
|
||||
|
||||
|
||||
class multiply(BaseModel):
|
||||
"""Product of two ints."""
|
||||
|
||||
x: int
|
||||
y: int
|
||||
|
||||
|
||||
@pytest.mark.requires("langchain_openai", "langchain_anthropic")
|
||||
async def test_init_chat_model_chain() -> None:
|
||||
model = init_chat_model("gpt-4o", configurable_fields="any", config_prefix="bar")
|
||||
model_with_tools = model.bind_tools([multiply])
|
||||
|
||||
model_with_config = model_with_tools.with_config(
|
||||
RunnableConfig(tags=["foo"]),
|
||||
configurable={"bar_model": "claude-3-sonnet-20240229"},
|
||||
)
|
||||
prompt = ChatPromptTemplate.from_messages([("system", "foo"), ("human", "{input}")])
|
||||
chain = prompt | model_with_config
|
||||
output = chain.invoke({"input": "bar"})
|
||||
assert isinstance(output, AIMessage)
|
||||
events = [
|
||||
event async for event in chain.astream_events({"input": "bar"}, version="v2")
|
||||
]
|
||||
assert events
|
||||
|
||||
|
||||
class TestStandard(ChatModelIntegrationTests):
|
||||
@property
|
||||
def chat_model_class(self) -> type[BaseChatModel]:
|
||||
return cast("type[BaseChatModel]", init_chat_model)
|
||||
|
||||
@property
|
||||
def chat_model_params(self) -> dict:
|
||||
return {"model": "gpt-4o", "configurable_fields": "any"}
|
||||
|
||||
@property
|
||||
def supports_image_inputs(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def has_tool_calling(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def has_structured_output(self) -> bool:
|
||||
return True
|
34
libs/langchain_v1/tests/integration_tests/conftest.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
# Getting the absolute path of the current file's directory
|
||||
ABS_PATH = Path(__file__).resolve().parent
|
||||
|
||||
# Getting the absolute path of the project's root directory
|
||||
PROJECT_DIR = ABS_PATH.parent.parent
|
||||
|
||||
|
||||
# Loading the .env file if it exists
|
||||
def _load_env() -> None:
|
||||
dotenv_path = PROJECT_DIR / "tests" / "integration_tests" / ".env"
|
||||
if dotenv_path.exists():
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv(dotenv_path)
|
||||
|
||||
|
||||
_load_env()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def test_dir() -> Path:
|
||||
return PROJECT_DIR / "tests" / "integration_tests"
|
||||
|
||||
|
||||
# This fixture returns a string containing the path to the cassette directory for the
|
||||
# current module
|
||||
@pytest.fixture(scope="module")
|
||||
def vcr_cassette_dir(request: pytest.FixtureRequest) -> str:
|
||||
module = Path(request.module.__file__)
|
||||
return str(module.parent / "cassettes" / module.stem)
|
@@ -0,0 +1,44 @@
|
||||
"""Test embeddings base module."""
|
||||
|
||||
import importlib
|
||||
|
||||
import pytest
|
||||
from langchain_core.embeddings import Embeddings
|
||||
|
||||
from langchain.embeddings.base import _SUPPORTED_PROVIDERS, init_embeddings
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("provider", "model"),
|
||||
[
|
||||
("openai", "text-embedding-3-large"),
|
||||
("google_vertexai", "text-embedding-gecko@003"),
|
||||
("bedrock", "amazon.titan-embed-text-v1"),
|
||||
("cohere", "embed-english-v2.0"),
|
||||
],
|
||||
)
|
||||
async def test_init_embedding_model(provider: str, model: str) -> None:
|
||||
package = _SUPPORTED_PROVIDERS[provider]
|
||||
try:
|
||||
importlib.import_module(package)
|
||||
except ImportError:
|
||||
pytest.skip(f"Package {package} is not installed")
|
||||
|
||||
model_colon = init_embeddings(f"{provider}:{model}")
|
||||
assert isinstance(model_colon, Embeddings)
|
||||
|
||||
model_explicit = init_embeddings(
|
||||
model=model,
|
||||
provider=provider,
|
||||
)
|
||||
assert isinstance(model_explicit, Embeddings)
|
||||
|
||||
text = "Hello world"
|
||||
|
||||
embedding_colon = await model_colon.aembed_query(text)
|
||||
assert isinstance(embedding_colon, list)
|
||||
assert all(isinstance(x, float) for x in embedding_colon)
|
||||
|
||||
embedding_explicit = await model_explicit.aembed_query(text)
|
||||
assert isinstance(embedding_explicit, list)
|
||||
assert all(isinstance(x, float) for x in embedding_explicit)
|
@@ -0,0 +1,6 @@
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.compile
|
||||
def test_placeholder() -> None:
|
||||
"""Used for compiling integration tests without running any real tests."""
|
0
libs/langchain_v1/tests/unit_tests/__init__.py
Normal file
@@ -0,0 +1,236 @@
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_core.runnables import RunnableConfig, RunnableSequence
|
||||
from pydantic import SecretStr
|
||||
|
||||
from langchain.chat_models import __all__, init_chat_model
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
|
||||
EXPECTED_ALL = [
|
||||
"init_chat_model",
|
||||
]
|
||||
|
||||
|
||||
def test_all_imports() -> None:
|
||||
"""Test that all expected imports are present in the module's __all__."""
|
||||
assert set(__all__) == set(EXPECTED_ALL)
|
||||
|
||||
|
||||
@pytest.mark.requires(
|
||||
"langchain_openai",
|
||||
"langchain_anthropic",
|
||||
"langchain_fireworks",
|
||||
"langchain_groq",
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
("model_name", "model_provider"),
|
||||
[
|
||||
("gpt-4o", "openai"),
|
||||
("claude-3-opus-20240229", "anthropic"),
|
||||
("accounts/fireworks/models/mixtral-8x7b-instruct", "fireworks"),
|
||||
("mixtral-8x7b-32768", "groq"),
|
||||
],
|
||||
)
|
||||
def test_init_chat_model(model_name: str, model_provider: Optional[str]) -> None:
|
||||
llm1: BaseChatModel = init_chat_model(
|
||||
model_name,
|
||||
model_provider=model_provider,
|
||||
api_key="foo",
|
||||
)
|
||||
llm2: BaseChatModel = init_chat_model(
|
||||
f"{model_provider}:{model_name}",
|
||||
api_key="foo",
|
||||
)
|
||||
assert llm1.dict() == llm2.dict()
|
||||
|
||||
|
||||
def test_init_missing_dep() -> None:
|
||||
with pytest.raises(ImportError):
|
||||
init_chat_model("mixtral-8x7b-32768", model_provider="groq")
|
||||
|
||||
|
||||
def test_init_unknown_provider() -> None:
|
||||
with pytest.raises(ValueError, match="Unsupported model_provider='bar'."):
|
||||
init_chat_model("foo", model_provider="bar")
|
||||
|
||||
|
||||
@pytest.mark.requires("langchain_openai")
|
||||
@mock.patch.dict(
|
||||
os.environ,
|
||||
{"OPENAI_API_KEY": "foo", "ANTHROPIC_API_KEY": "bar"},
|
||||
clear=True,
|
||||
)
|
||||
def test_configurable() -> None:
|
||||
model = init_chat_model()
|
||||
|
||||
for method in (
|
||||
"invoke",
|
||||
"ainvoke",
|
||||
"batch",
|
||||
"abatch",
|
||||
"stream",
|
||||
"astream",
|
||||
"batch_as_completed",
|
||||
"abatch_as_completed",
|
||||
):
|
||||
assert hasattr(model, method)
|
||||
|
||||
# Doesn't have access to non-configurable, non-declarative methods until a config is
|
||||
# provided.
|
||||
for method in ("get_num_tokens", "get_num_tokens_from_messages"):
|
||||
with pytest.raises(AttributeError):
|
||||
getattr(model, method)
|
||||
|
||||
# Can call declarative methods even without a default model.
|
||||
model_with_tools = model.bind_tools(
|
||||
[{"name": "foo", "description": "foo", "parameters": {}}],
|
||||
)
|
||||
|
||||
# Check that original model wasn't mutated by declarative operation.
|
||||
assert model._queued_declarative_operations == []
|
||||
|
||||
# Can iteratively call declarative methods.
|
||||
model_with_config = model_with_tools.with_config(
|
||||
RunnableConfig(tags=["foo"]),
|
||||
configurable={"model": "gpt-4o"},
|
||||
)
|
||||
assert model_with_config.model_name == "gpt-4o" # type: ignore[attr-defined]
|
||||
|
||||
for method in ("get_num_tokens", "get_num_tokens_from_messages"):
|
||||
assert hasattr(model_with_config, method)
|
||||
|
||||
assert model_with_config.model_dump() == { # type: ignore[attr-defined]
|
||||
"name": None,
|
||||
"bound": {
|
||||
"name": None,
|
||||
"disable_streaming": False,
|
||||
"disabled_params": None,
|
||||
"model_name": "gpt-4o",
|
||||
"temperature": None,
|
||||
"model_kwargs": {},
|
||||
"openai_api_key": SecretStr("foo"),
|
||||
"openai_api_base": None,
|
||||
"openai_organization": None,
|
||||
"openai_proxy": None,
|
||||
"output_version": "v0",
|
||||
"request_timeout": None,
|
||||
"max_retries": None,
|
||||
"presence_penalty": None,
|
||||
"reasoning": None,
|
||||
"reasoning_effort": None,
|
||||
"frequency_penalty": None,
|
||||
"include": None,
|
||||
"seed": None,
|
||||
"service_tier": None,
|
||||
"logprobs": None,
|
||||
"top_logprobs": None,
|
||||
"logit_bias": None,
|
||||
"streaming": False,
|
||||
"n": None,
|
||||
"top_p": None,
|
||||
"truncation": None,
|
||||
"max_tokens": None,
|
||||
"tiktoken_model_name": None,
|
||||
"default_headers": None,
|
||||
"default_query": None,
|
||||
"stop": None,
|
||||
"store": None,
|
||||
"extra_body": None,
|
||||
"include_response_headers": False,
|
||||
"stream_usage": False,
|
||||
"use_previous_response_id": False,
|
||||
"use_responses_api": None,
|
||||
},
|
||||
"kwargs": {
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {"name": "foo", "description": "foo", "parameters": {}},
|
||||
},
|
||||
],
|
||||
},
|
||||
"config": {"tags": ["foo"], "configurable": {}},
|
||||
"config_factories": [],
|
||||
"custom_input_type": None,
|
||||
"custom_output_type": None,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.requires("langchain_openai", "langchain_anthropic")
|
||||
@mock.patch.dict(
|
||||
os.environ,
|
||||
{"OPENAI_API_KEY": "foo", "ANTHROPIC_API_KEY": "bar"},
|
||||
clear=True,
|
||||
)
|
||||
def test_configurable_with_default() -> None:
|
||||
model = init_chat_model("gpt-4o", configurable_fields="any", config_prefix="bar")
|
||||
for method in (
|
||||
"invoke",
|
||||
"ainvoke",
|
||||
"batch",
|
||||
"abatch",
|
||||
"stream",
|
||||
"astream",
|
||||
"batch_as_completed",
|
||||
"abatch_as_completed",
|
||||
):
|
||||
assert hasattr(model, method)
|
||||
|
||||
# Does have access non-configurable, non-declarative methods since default params
|
||||
# are provided.
|
||||
for method in ("get_num_tokens", "get_num_tokens_from_messages", "dict"):
|
||||
assert hasattr(model, method)
|
||||
|
||||
assert model.model_name == "gpt-4o"
|
||||
|
||||
model_with_tools = model.bind_tools(
|
||||
[{"name": "foo", "description": "foo", "parameters": {}}],
|
||||
)
|
||||
|
||||
model_with_config = model_with_tools.with_config(
|
||||
RunnableConfig(tags=["foo"]),
|
||||
configurable={"bar_model": "claude-3-sonnet-20240229"},
|
||||
)
|
||||
|
||||
assert model_with_config.model == "claude-3-sonnet-20240229" # type: ignore[attr-defined]
|
||||
|
||||
assert model_with_config.model_dump() == { # type: ignore[attr-defined]
|
||||
"name": None,
|
||||
"bound": {
|
||||
"name": None,
|
||||
"disable_streaming": False,
|
||||
"model": "claude-3-sonnet-20240229",
|
||||
"mcp_servers": None,
|
||||
"max_tokens": 1024,
|
||||
"temperature": None,
|
||||
"thinking": None,
|
||||
"top_k": None,
|
||||
"top_p": None,
|
||||
"default_request_timeout": None,
|
||||
"max_retries": 2,
|
||||
"stop_sequences": None,
|
||||
"anthropic_api_url": "https://api.anthropic.com",
|
||||
"anthropic_api_key": SecretStr("bar"),
|
||||
"betas": None,
|
||||
"default_headers": None,
|
||||
"model_kwargs": {},
|
||||
"streaming": False,
|
||||
"stream_usage": True,
|
||||
},
|
||||
"kwargs": {
|
||||
"tools": [{"name": "foo", "description": "foo", "input_schema": {}}],
|
||||
},
|
||||
"config": {"tags": ["foo"], "configurable": {}},
|
||||
"config_factories": [],
|
||||
"custom_input_type": None,
|
||||
"custom_output_type": None,
|
||||
}
|
||||
prompt = ChatPromptTemplate.from_messages([("system", "foo")])
|
||||
chain = prompt | model_with_config
|
||||
assert isinstance(chain, RunnableSequence)
|
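Note: the behavior exercised above reduces to a short usage sketch. This is illustrative only, not part of the diff; it assumes langchain_openai is installed and OPENAI_API_KEY is set, and the model names are placeholders.

    from langchain.chat_models import init_chat_model

    # Fixed model: the provider can be given explicitly or inferred from a
    # "provider:model" string, as the parametrized test above checks.
    llm = init_chat_model("openai:gpt-4o")

    # Configurable model: with no default, the concrete model is chosen per call
    # via the "model" key of the config's `configurable` dict.
    configurable_llm = init_chat_model()
    configurable_llm.invoke("Hello", config={"configurable": {"model": "gpt-4o"}})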
122
libs/langchain_v1/tests/unit_tests/conftest.py
Normal file
122
libs/langchain_v1/tests/unit_tests/conftest.py
Normal file
@@ -0,0 +1,122 @@
"""Configuration for unit tests."""

from collections.abc import Iterator, Sequence
from importlib import util

import pytest
from blockbuster import blockbuster_ctx


@pytest.fixture(autouse=True)
def blockbuster() -> Iterator[None]:
    with blockbuster_ctx("langchain") as bb:
        bb.functions["io.TextIOWrapper.read"].can_block_in(
            "langchain/__init__.py",
            "<module>",
        )

        for func in ["os.stat", "os.path.abspath"]:
            (
                bb.functions[func]
                .can_block_in("langchain_core/runnables/base.py", "__repr__")
                .can_block_in(
                    "langchain_core/beta/runnables/context.py",
                    "aconfig_with_context",
                )
            )

        for func in ["os.stat", "io.TextIOWrapper.read"]:
            bb.functions[func].can_block_in(
                "langsmith/client.py",
                "_default_retry_config",
            )

        for bb_function in bb.functions.values():
            bb_function.can_block_in(
                "freezegun/api.py",
                "_get_cached_module_attributes",
            )
        yield


def pytest_addoption(parser: pytest.Parser) -> None:
    """Add custom command line options to pytest."""
    parser.addoption(
        "--only-extended",
        action="store_true",
        help="Only run extended tests. Does not allow skipping any extended tests.",
    )
    parser.addoption(
        "--only-core",
        action="store_true",
        help="Only run core tests. Never runs any extended tests.",
    )


def pytest_collection_modifyitems(
    config: pytest.Config, items: Sequence[pytest.Function]
) -> None:
    """Add implementations for handling custom markers.

    At the moment, this adds support for a custom `requires` marker.

    The `requires` marker is used to denote tests that require one or more packages
    to be installed to run. If the package is not installed, the test is skipped.

    The `requires` marker syntax is:

    .. code-block:: python

        @pytest.mark.requires("package1", "package2")
        def test_something():
            ...
    """
    # Mapping from the name of a package to whether it is installed or not.
    # Used to avoid repeated calls to `util.find_spec`
    required_pkgs_info: dict[str, bool] = {}

    only_extended = config.getoption("--only-extended") or False
    only_core = config.getoption("--only-core") or False

    if only_extended and only_core:
        msg = "Cannot specify both `--only-extended` and `--only-core`."
        raise ValueError(msg)

    for item in items:
        requires_marker = item.get_closest_marker("requires")
        if requires_marker is not None:
            if only_core:
                item.add_marker(pytest.mark.skip(reason="Skipping not a core test."))
                continue

            # Iterate through the list of required packages
            required_pkgs = requires_marker.args
            for pkg in required_pkgs:
                # If we haven't yet checked whether the pkg is installed
                # let's check it and store the result.
                if pkg not in required_pkgs_info:
                    try:
                        installed = util.find_spec(pkg) is not None
                    except Exception:
                        installed = False
                    required_pkgs_info[pkg] = installed

                if not required_pkgs_info[pkg]:
                    if only_extended:
                        pytest.fail(
                            f"Package `{pkg}` is not installed but is required for "
                            f"extended tests. Please install the given package and "
                            f"try again.",
                        )

                    else:
                        # If the package is not installed, we immediately break
                        # and mark the test as skipped.
                        item.add_marker(
                            pytest.mark.skip(reason=f"Requires pkg: `{pkg}`"),
                        )
                        break
        elif only_extended:
            item.add_marker(
                pytest.mark.skip(reason="Skipping not an extended test."),
            )
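For reference, a minimal sketch of a test that opts into the `requires` marker handled by the hook above; the package name is purely illustrative. Such a test is skipped when langchain_openai is not importable and fails under --only-extended.

    import pytest


    @pytest.mark.requires("langchain_openai")
    def test_partner_package_feature() -> None:
        # Only runs when the optional partner package is installed.
        from langchain_openai import ChatOpenAI

        assert ChatOpenAI is not None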
111
libs/langchain_v1/tests/unit_tests/embeddings/test_base.py
Normal file
111
libs/langchain_v1/tests/unit_tests/embeddings/test_base.py
Normal file
@@ -0,0 +1,111 @@
"""Test embeddings base module."""

import pytest

from langchain.embeddings.base import (
    _SUPPORTED_PROVIDERS,
    _infer_model_and_provider,
    _parse_model_string,
)


def test_parse_model_string() -> None:
    """Test parsing model strings into provider and model components."""
    assert _parse_model_string("openai:text-embedding-3-small") == (
        "openai",
        "text-embedding-3-small",
    )
    assert _parse_model_string("bedrock:amazon.titan-embed-text-v1") == (
        "bedrock",
        "amazon.titan-embed-text-v1",
    )
    assert _parse_model_string("huggingface:BAAI/bge-base-en:v1.5") == (
        "huggingface",
        "BAAI/bge-base-en:v1.5",
    )


def test_parse_model_string_errors() -> None:
    """Test error cases for model string parsing."""
    with pytest.raises(ValueError, match="Model name must be"):
        _parse_model_string("just-a-model-name")

    with pytest.raises(ValueError, match="Invalid model format "):
        _parse_model_string("")

    with pytest.raises(ValueError, match="is not supported"):
        _parse_model_string(":model-name")

    with pytest.raises(ValueError, match="Model name cannot be empty"):
        _parse_model_string("openai:")

    with pytest.raises(
        ValueError,
        match="Provider 'invalid-provider' is not supported",
    ):
        _parse_model_string("invalid-provider:model-name")

    for provider in _SUPPORTED_PROVIDERS:
        with pytest.raises(ValueError, match=f"{provider}"):
            _parse_model_string("invalid-provider:model-name")


def test_infer_model_and_provider() -> None:
    """Test model and provider inference from different input formats."""
    assert _infer_model_and_provider("openai:text-embedding-3-small") == (
        "openai",
        "text-embedding-3-small",
    )

    assert _infer_model_and_provider(
        model="text-embedding-3-small",
        provider="openai",
    ) == ("openai", "text-embedding-3-small")

    assert _infer_model_and_provider(
        model="ft:text-embedding-3-small",
        provider="openai",
    ) == ("openai", "ft:text-embedding-3-small")

    assert _infer_model_and_provider(model="openai:ft:text-embedding-3-small") == (
        "openai",
        "ft:text-embedding-3-small",
    )


def test_infer_model_and_provider_errors() -> None:
    """Test error cases for model and provider inference."""
    # Test missing provider
    with pytest.raises(ValueError, match="Must specify either"):
        _infer_model_and_provider("text-embedding-3-small")

    # Test empty model
    with pytest.raises(ValueError, match="Model name cannot be empty"):
        _infer_model_and_provider("")

    # Test empty provider with model
    with pytest.raises(ValueError, match="Must specify either"):
        _infer_model_and_provider("model", provider="")

    # Test invalid provider
    with pytest.raises(ValueError, match="Provider 'invalid' is not supported.") as exc:
        _infer_model_and_provider("model", provider="invalid")
    # Test provider list is in error
    for provider in _SUPPORTED_PROVIDERS:
        assert provider in str(exc.value)


@pytest.mark.parametrize(
    "provider",
    sorted(_SUPPORTED_PROVIDERS.keys()),
)
def test_supported_providers_package_names(provider: str) -> None:
    """Test that all supported providers have valid package names."""
    package = _SUPPORTED_PROVIDERS[provider]
    assert "-" not in package
    assert package.startswith("langchain_")
    assert package.islower()


def test_is_sorted() -> None:
    assert list(_SUPPORTED_PROVIDERS) == sorted(_SUPPORTED_PROVIDERS.keys())
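The "provider:model" format validated here is the same one accepted by init_embeddings (exported from langchain.embeddings, per the import test further below). A minimal sketch, assuming langchain_openai is installed and credentials are configured; the model name is a placeholder.

    from langchain.embeddings import init_embeddings

    # _parse_model_string splits this into ("openai", "text-embedding-3-small");
    # the provider may also be passed separately via the `provider` keyword.
    embedder = init_embeddings("openai:text-embedding-3-small")
    vector = embedder.embed_query("hello world")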
250
libs/langchain_v1/tests/unit_tests/embeddings/test_caching.py
Normal file
250
libs/langchain_v1/tests/unit_tests/embeddings/test_caching.py
Normal file
@@ -0,0 +1,250 @@
"""Embeddings tests."""

import contextlib
import hashlib
import importlib
import warnings

import pytest
from langchain_core.embeddings import Embeddings

from langchain.embeddings import CacheBackedEmbeddings
from langchain.storage.in_memory import InMemoryStore


class MockEmbeddings(Embeddings):
    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        # Simulate embedding documents
        embeddings: list[list[float]] = []
        for text in texts:
            if text == "RAISE_EXCEPTION":
                msg = "Simulated embedding failure"
                raise ValueError(msg)
            embeddings.append([len(text), len(text) + 1])
        return embeddings

    def embed_query(self, text: str) -> list[float]:
        # Simulate embedding a query
        return [5.0, 6.0]


@pytest.fixture
def cache_embeddings() -> CacheBackedEmbeddings:
    """Create a cache backed embeddings."""
    store = InMemoryStore()
    embeddings = MockEmbeddings()
    return CacheBackedEmbeddings.from_bytes_store(
        embeddings,
        store,
        namespace="test_namespace",
    )


@pytest.fixture
def cache_embeddings_batch() -> CacheBackedEmbeddings:
    """Create a cache backed embeddings with a batch_size of 3."""
    store = InMemoryStore()
    embeddings = MockEmbeddings()
    return CacheBackedEmbeddings.from_bytes_store(
        embeddings,
        store,
        namespace="test_namespace",
        batch_size=3,
    )


@pytest.fixture
def cache_embeddings_with_query() -> CacheBackedEmbeddings:
    """Create a cache backed embeddings with query caching."""
    doc_store = InMemoryStore()
    query_store = InMemoryStore()
    embeddings = MockEmbeddings()
    return CacheBackedEmbeddings.from_bytes_store(
        embeddings,
        document_embedding_cache=doc_store,
        namespace="test_namespace",
        query_embedding_cache=query_store,
    )


def test_embed_documents(cache_embeddings: CacheBackedEmbeddings) -> None:
    texts = ["1", "22", "a", "333"]
    vectors = cache_embeddings.embed_documents(texts)
    expected_vectors: list[list[float]] = [[1, 2.0], [2.0, 3.0], [1.0, 2.0], [3.0, 4.0]]
    assert vectors == expected_vectors
    keys = list(cache_embeddings.document_embedding_store.yield_keys())
    assert len(keys) == 4
    # UUID is expected to be the same for the same text
    assert keys[0] == "test_namespace812b86c1-8ebf-5483-95c6-c95cf2b52d12"


def test_embed_documents_batch(cache_embeddings_batch: CacheBackedEmbeddings) -> None:
    # "RAISE_EXCEPTION" forces a failure in batch 2
    texts = ["1", "22", "a", "333", "RAISE_EXCEPTION"]
    with contextlib.suppress(ValueError):
        cache_embeddings_batch.embed_documents(texts)
    keys = list(cache_embeddings_batch.document_embedding_store.yield_keys())
    # only the first batch of three embeddings should exist
    assert len(keys) == 3
    # UUID is expected to be the same for the same text
    assert keys[0] == "test_namespace812b86c1-8ebf-5483-95c6-c95cf2b52d12"


def test_embed_query(cache_embeddings: CacheBackedEmbeddings) -> None:
    text = "query_text"
    vector = cache_embeddings.embed_query(text)
    expected_vector = [5.0, 6.0]
    assert vector == expected_vector
    assert cache_embeddings.query_embedding_store is None


def test_embed_cached_query(cache_embeddings_with_query: CacheBackedEmbeddings) -> None:
    text = "query_text"
    vector = cache_embeddings_with_query.embed_query(text)
    expected_vector = [5.0, 6.0]
    assert vector == expected_vector
    keys = list(cache_embeddings_with_query.query_embedding_store.yield_keys())  # type: ignore[union-attr]
    assert len(keys) == 1
    assert keys[0] == "test_namespace89ec3dae-a4d9-5636-a62e-ff3b56cdfa15"


async def test_aembed_documents(cache_embeddings: CacheBackedEmbeddings) -> None:
    texts = ["1", "22", "a", "333"]
    vectors = await cache_embeddings.aembed_documents(texts)
    expected_vectors: list[list[float]] = [[1, 2.0], [2.0, 3.0], [1.0, 2.0], [3.0, 4.0]]
    assert vectors == expected_vectors
    keys = [
        key async for key in cache_embeddings.document_embedding_store.ayield_keys()
    ]
    assert len(keys) == 4
    # UUID is expected to be the same for the same text
    assert keys[0] == "test_namespace812b86c1-8ebf-5483-95c6-c95cf2b52d12"


async def test_aembed_documents_batch(
    cache_embeddings_batch: CacheBackedEmbeddings,
) -> None:
    # "RAISE_EXCEPTION" forces a failure in batch 2
    texts = ["1", "22", "a", "333", "RAISE_EXCEPTION"]
    with contextlib.suppress(ValueError):
        await cache_embeddings_batch.aembed_documents(texts)
    keys = [
        key
        async for key in cache_embeddings_batch.document_embedding_store.ayield_keys()
    ]
    # only the first batch of three embeddings should exist
    assert len(keys) == 3
    # UUID is expected to be the same for the same text
    assert keys[0] == "test_namespace812b86c1-8ebf-5483-95c6-c95cf2b52d12"


async def test_aembed_query(cache_embeddings: CacheBackedEmbeddings) -> None:
    text = "query_text"
    vector = await cache_embeddings.aembed_query(text)
    expected_vector = [5.0, 6.0]
    assert vector == expected_vector


async def test_aembed_query_cached(
    cache_embeddings_with_query: CacheBackedEmbeddings,
) -> None:
    text = "query_text"
    await cache_embeddings_with_query.aembed_query(text)
    keys = list(cache_embeddings_with_query.query_embedding_store.yield_keys())  # type: ignore[union-attr]
    assert len(keys) == 1
    assert keys[0] == "test_namespace89ec3dae-a4d9-5636-a62e-ff3b56cdfa15"


def test_blake2b_encoder() -> None:
    """Test that the blake2b encoder is used to encode keys in the cache store."""
    store = InMemoryStore()
    emb = MockEmbeddings()
    cbe = CacheBackedEmbeddings.from_bytes_store(
        emb,
        store,
        namespace="ns_",
        key_encoder="blake2b",
    )

    text = "blake"
    cbe.embed_documents([text])

    # rebuild the key exactly as the library does
    expected_key = "ns_" + hashlib.blake2b(text.encode()).hexdigest()
    assert list(cbe.document_embedding_store.yield_keys()) == [expected_key]


def test_sha256_encoder() -> None:
    """Test that the sha256 encoder is used to encode keys in the cache store."""
    store = InMemoryStore()
    emb = MockEmbeddings()
    cbe = CacheBackedEmbeddings.from_bytes_store(
        emb,
        store,
        namespace="ns_",
        key_encoder="sha256",
    )

    text = "foo"
    cbe.embed_documents([text])

    # rebuild the key exactly as the library does
    expected_key = "ns_" + hashlib.sha256(text.encode()).hexdigest()
    assert list(cbe.document_embedding_store.yield_keys()) == [expected_key]


def test_sha512_encoder() -> None:
    """Test that the sha512 encoder is used to encode keys in the cache store."""
    store = InMemoryStore()
    emb = MockEmbeddings()
    cbe = CacheBackedEmbeddings.from_bytes_store(
        emb,
        store,
        namespace="ns_",
        key_encoder="sha512",
    )

    text = "foo"
    cbe.embed_documents([text])

    # rebuild the key exactly as the library does
    expected_key = "ns_" + hashlib.sha512(text.encode()).hexdigest()
    assert list(cbe.document_embedding_store.yield_keys()) == [expected_key]


def test_sha1_warning_emitted_once() -> None:
    """Test that a warning is emitted when using SHA-1 as the default key encoder."""
    module = importlib.import_module(CacheBackedEmbeddings.__module__)

    # Create a *temporary* MonkeyPatch object whose effects disappear
    # automatically when the with-block exits.
    with pytest.MonkeyPatch.context() as mp:
        # We're monkey patching the module to reset the `_warned_about_sha1` flag
        # which may have been set while testing other parts of the codebase.
        mp.setattr(module, "_warned_about_sha1", False, raising=False)

        store = InMemoryStore()
        emb = MockEmbeddings()

        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            CacheBackedEmbeddings.from_bytes_store(emb, store)  # triggers warning
            CacheBackedEmbeddings.from_bytes_store(emb, store)  # silent

        sha1_msgs = [w for w in caught if "SHA-1" in str(w.message)]
        assert len(sha1_msgs) == 1


def test_custom_encoder() -> None:
    """Test that a custom encoder can be used to encode keys in the cache store."""
    store = InMemoryStore()
    emb = MockEmbeddings()

    def custom_upper(text: str) -> str:  # very simple demo encoder
        return "CUSTOM_" + text.upper()

    cbe = CacheBackedEmbeddings.from_bytes_store(emb, store, key_encoder=custom_upper)
    txt = "x"
    cbe.embed_documents([txt])

    assert list(cbe.document_embedding_store.yield_keys()) == ["CUSTOM_X"]
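The caching behavior covered by these tests comes down to the following sketch; FakeEmbeddings stands in for any real Embeddings implementation, and the namespace value is arbitrary.

    from langchain_core.embeddings import FakeEmbeddings

    from langchain.embeddings import CacheBackedEmbeddings
    from langchain.storage import InMemoryByteStore

    underlying = FakeEmbeddings(size=4)
    store = InMemoryByteStore()

    cached = CacheBackedEmbeddings.from_bytes_store(
        underlying,
        store,
        namespace="fake-4",  # keeps keys from different models separate
        key_encoder="sha256",  # avoids the SHA-1 default warning tested above
    )

    # The first call computes and stores vectors; the second is served from the cache.
    cached.embed_documents(["hello", "world"])
    cached.embed_documents(["hello", "world"])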
@@ -0,0 +1,10 @@
from langchain import embeddings

EXPECTED_ALL = [
    "CacheBackedEmbeddings",
    "init_embeddings",
]


def test_all_imports() -> None:
    assert set(embeddings.__all__) == set(EXPECTED_ALL)
1
libs/langchain_v1/tests/unit_tests/prompts/__init__.py
Normal file
1
libs/langchain_v1/tests/unit_tests/prompts/__init__.py
Normal file
@@ -0,0 +1 @@
"""Test prompt functionality."""
24
libs/langchain_v1/tests/unit_tests/prompts/test_imports.py
Normal file
24
libs/langchain_v1/tests/unit_tests/prompts/test_imports.py
Normal file
@@ -0,0 +1,24 @@
from langchain import prompts

EXPECTED_ALL = [
    "AIMessagePromptTemplate",
    "BaseChatPromptTemplate",
    "BasePromptTemplate",
    "ChatMessagePromptTemplate",
    "ChatPromptTemplate",
    "FewShotPromptTemplate",
    "FewShotPromptWithTemplates",
    "HumanMessagePromptTemplate",
    "LengthBasedExampleSelector",
    "MaxMarginalRelevanceExampleSelector",
    "MessagesPlaceholder",
    "PromptTemplate",
    "SemanticSimilarityExampleSelector",
    "StringPromptTemplate",
    "SystemMessagePromptTemplate",
    "FewShotChatMessagePromptTemplate",
]


def test_all_imports() -> None:
    assert set(prompts.__all__) == set(EXPECTED_ALL)
155
libs/langchain_v1/tests/unit_tests/storage/test_filesystem.py
Normal file
155
libs/langchain_v1/tests/unit_tests/storage/test_filesystem.py
Normal file
@@ -0,0 +1,155 @@
import tempfile
from collections.abc import Generator
from pathlib import Path

import pytest
from langchain_core.stores import InvalidKeyException

from langchain.storage.file_system import LocalFileStore


@pytest.fixture
def file_store() -> Generator[LocalFileStore, None, None]:
    # Create a temporary directory for testing
    with tempfile.TemporaryDirectory() as temp_dir:
        # Instantiate the LocalFileStore with the temporary directory as the root path
        store = LocalFileStore(temp_dir)
        yield store


def test_mset_and_mget(file_store: LocalFileStore) -> None:
    # Set values for keys
    key_value_pairs = [("key1", b"value1"), ("key2", b"value2")]
    file_store.mset(key_value_pairs)

    # Get values for keys
    values = file_store.mget(["key1", "key2"])

    # Assert that the retrieved values match the original values
    assert values == [b"value1", b"value2"]


@pytest.mark.parametrize(
    ("chmod_dir_s", "chmod_file_s"),
    [("777", "666"), ("770", "660"), ("700", "600")],
)
def test_mset_chmod(chmod_dir_s: str, chmod_file_s: str) -> None:
    chmod_dir = int(chmod_dir_s, base=8)
    chmod_file = int(chmod_file_s, base=8)

    # Create a temporary directory for testing
    with tempfile.TemporaryDirectory() as temp_dir:
        # Instantiate the LocalFileStore with a directory inside the temporary
        # directory as the root path
        file_store = LocalFileStore(
            Path(temp_dir) / "store_dir",
            chmod_dir=chmod_dir,
            chmod_file=chmod_file,
        )

        # Set values for keys
        key_value_pairs = [("key1", b"value1"), ("key2", b"value2")]
        file_store.mset(key_value_pairs)

        # verify the permissions are set correctly
        # (test only the standard user/group/other bits)
        dir_path = file_store.root_path
        file_path = file_store.root_path / "key1"
        assert (dir_path.stat().st_mode & 0o777) == chmod_dir
        assert (file_path.stat().st_mode & 0o777) == chmod_file


def test_mget_update_atime() -> None:
    # Create a temporary directory for testing
    with tempfile.TemporaryDirectory() as temp_dir:
        # Instantiate the LocalFileStore with a directory inside the temporary
        # directory as the root path
        file_store = LocalFileStore(Path(temp_dir) / "store_dir", update_atime=True)

        # Set values for keys
        key_value_pairs = [("key1", b"value1"), ("key2", b"value2")]
        file_store.mset(key_value_pairs)

        # Get original access time
        file_path = file_store.root_path / "key1"
        atime1 = file_path.stat().st_atime

        # Get values for keys
        _ = file_store.mget(["key1", "key2"])

        # Make sure the filesystem access time has been updated
        atime2 = file_path.stat().st_atime
        assert atime2 != atime1


def test_mdelete(file_store: LocalFileStore) -> None:
    # Set values for keys
    key_value_pairs = [("key1", b"value1"), ("key2", b"value2")]
    file_store.mset(key_value_pairs)

    # Delete keys
    file_store.mdelete(["key1"])

    # Check if the deleted key is present
    values = file_store.mget(["key1"])

    # Assert that the value is None after deletion
    assert values == [None]


def test_set_invalid_key(file_store: LocalFileStore) -> None:
    """Test that an exception is raised when an invalid key is set."""
    # Set a key-value pair
    key = "crying-cat/😿"
    value = b"This is a test value"
    with pytest.raises(InvalidKeyException):
        file_store.mset([(key, value)])


def test_set_key_and_verify_content(file_store: LocalFileStore) -> None:
    """Test that the content of the file is the same as the value set."""
    # Set a key-value pair
    key = "test_key"
    value = b"This is a test value"
    file_store.mset([(key, value)])

    # Verify the content of the actual file
    full_path = file_store._get_full_path(key)
    assert full_path.exists()
    assert full_path.read_bytes() == b"This is a test value"


def test_yield_keys(file_store: LocalFileStore) -> None:
    # Set values for keys
    key_value_pairs = [("key1", b"value1"), ("subdir/key2", b"value2")]
    file_store.mset(key_value_pairs)

    # Iterate over keys
    keys = list(file_store.yield_keys())

    # Assert that the yielded keys match the expected keys
    expected_keys = ["key1", str(Path("subdir") / "key2")]
    assert keys == expected_keys


def test_catches_forbidden_keys(file_store: LocalFileStore) -> None:
    """Make sure we raise exception on keys that are not allowed; e.g., absolute path"""
    with pytest.raises(InvalidKeyException):
        file_store.mset([("/etc", b"value1")])
    with pytest.raises(InvalidKeyException):
        list(file_store.yield_keys("/etc/passwd"))
    with pytest.raises(InvalidKeyException):
        file_store.mget(["/etc/passwd"])

    # check relative paths
    with pytest.raises(InvalidKeyException):
        list(file_store.yield_keys(".."))

    with pytest.raises(InvalidKeyException):
        file_store.mget(["../etc/passwd"])

    with pytest.raises(InvalidKeyException):
        file_store.mset([("../etc", b"value1")])

    with pytest.raises(InvalidKeyException):
        list(file_store.yield_keys("../etc/passwd"))
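For context, the byte-store interface these tests exercise looks like this in direct use; the path and file mode are placeholders.

    from langchain.storage import LocalFileStore

    store = LocalFileStore("/tmp/langchain_kv", chmod_file=0o600)

    store.mset([("key1", b"value1"), ("key2", b"value2")])
    assert store.mget(["key1", "missing"]) == [b"value1", None]

    store.mdelete(["key1"])
    assert list(store.yield_keys()) == ["key2"]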
15
libs/langchain_v1/tests/unit_tests/storage/test_imports.py
Normal file
15
libs/langchain_v1/tests/unit_tests/storage/test_imports.py
Normal file
@@ -0,0 +1,15 @@
from langchain import storage

EXPECTED_ALL = [
    "EncoderBackedStore",
    "InMemoryStore",
    "InMemoryByteStore",
    "LocalFileStore",
    "InvalidKeyException",
    "create_lc_store",
    "create_kv_docstore",
]


def test_all_imports() -> None:
    assert set(storage.__all__) == set(EXPECTED_ALL)
37
libs/langchain_v1/tests/unit_tests/storage/test_lc_store.py
Normal file
37
libs/langchain_v1/tests/unit_tests/storage/test_lc_store.py
Normal file
@@ -0,0 +1,37 @@
import tempfile
from collections.abc import Generator
from typing import cast

import pytest
from langchain_core.documents import Document

from langchain.storage._lc_store import create_kv_docstore, create_lc_store
from langchain.storage.file_system import LocalFileStore


@pytest.fixture
def file_store() -> Generator[LocalFileStore, None, None]:
    # Create a temporary directory for testing
    with tempfile.TemporaryDirectory() as temp_dir:
        # Instantiate the LocalFileStore with the temporary directory as the root path
        store = LocalFileStore(temp_dir)
        yield store


def test_create_lc_store(file_store: LocalFileStore) -> None:
    """Test that a docstore is created from a base store."""
    docstore = create_lc_store(file_store)
    docstore.mset([("key1", Document(page_content="hello", metadata={"key": "value"}))])
    fetched_doc = cast("Document", docstore.mget(["key1"])[0])
    assert fetched_doc.page_content == "hello"
    assert fetched_doc.metadata == {"key": "value"}


def test_create_kv_store(file_store: LocalFileStore) -> None:
    """Test that a docstore is created from a base store."""
    docstore = create_kv_docstore(file_store)
    docstore.mset([("key1", Document(page_content="hello", metadata={"key": "value"}))])
    fetched_doc = docstore.mget(["key1"])[0]
    assert isinstance(fetched_doc, Document)
    assert fetched_doc.page_content == "hello"
    assert fetched_doc.metadata == {"key": "value"}
libs/langchain_v1/tests/unit_tests/stubs.py
Normal file
46
libs/langchain_v1/tests/unit_tests/stubs.py
Normal file
@@ -0,0 +1,46 @@
from typing import Any

from langchain_core.documents import Document
from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage


class AnyStr(str):
    __slots__ = ()

    def __eq__(self, other: object) -> bool:
        return isinstance(other, str)


# The code below creates versions of pydantic models
# that will work in unit tests with AnyStr as the id field.
# Please note that the `id` field is assigned AFTER the model is created
# to work around an issue with pydantic ignoring the __eq__ method on
# subclassed strings.


def _AnyIdDocument(**kwargs: Any) -> Document:
    """Create a document with an AnyStr id field."""
    message = Document(**kwargs)
    message.id = AnyStr()
    return message


def _AnyIdAIMessage(**kwargs: Any) -> AIMessage:
    """Create an AI message with an AnyStr id field."""
    message = AIMessage(**kwargs)
    message.id = AnyStr()
    return message


def _AnyIdAIMessageChunk(**kwargs: Any) -> AIMessageChunk:
    """Create an AI message chunk with an AnyStr id field."""
    message = AIMessageChunk(**kwargs)
    message.id = AnyStr()
    return message


def _AnyIdHumanMessage(**kwargs: Any) -> HumanMessage:
    """Create a human message with an AnyStr id field."""
    message = HumanMessage(**kwargs)
    message.id = AnyStr()
    return message
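These helpers let equality assertions ignore whatever `id` a message or document was given. A minimal sketch of the intended usage (the id value here is arbitrary):

    from langchain_core.messages import AIMessage

    from tests.unit_tests.stubs import _AnyIdAIMessage

    # A message produced by a model or test harness carries some generated id;
    # the stub compares equal no matter what that id is, since AnyStr matches any str.
    actual = AIMessage(content="hello", id="run-123")
    assert actual == _AnyIdAIMessage(content="hello")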
35
libs/langchain_v1/tests/unit_tests/test_dependencies.py
Normal file
35
libs/langchain_v1/tests/unit_tests/test_dependencies.py
Normal file
@@ -0,0 +1,35 @@
"""A unit test meant to catch accidental introduction of non-optional dependencies."""

from collections.abc import Mapping
from pathlib import Path
from typing import Any

import pytest
import toml
from packaging.requirements import Requirement

HERE = Path(__file__).parent

PYPROJECT_TOML = HERE / "../../pyproject.toml"


@pytest.fixture
def uv_conf() -> dict[str, Any]:
    """Load the pyproject.toml file."""
    with PYPROJECT_TOML.open() as f:
        return toml.load(f)


def test_required_dependencies(uv_conf: Mapping[str, Any]) -> None:
    """A test that checks if a new non-optional dependency is being introduced.

    If this test is triggered, it means that a contributor is trying to introduce a new
    required dependency. This should be avoided in most situations.
    """
    # Get the required dependencies from the [project] section of pyproject.toml
    dependencies = uv_conf["project"]["dependencies"]
    required_dependencies = {Requirement(dep).name for dep in dependencies}

    assert sorted(required_dependencies) == sorted(
        ["langchain-core", "langchain-text-splitters", "langgraph", "pydantic"]
    )
60
libs/langchain_v1/tests/unit_tests/test_imports.py
Normal file
60
libs/langchain_v1/tests/unit_tests/test_imports.py
Normal file
@@ -0,0 +1,60 @@
import importlib
import warnings
from pathlib import Path

# Attempt to recursively import all modules in langchain
PKG_ROOT = Path(__file__).parent.parent.parent


def test_import_all() -> None:
    """Test that every name listed in a module's __all__ can be resolved."""
    with warnings.catch_warnings():
        warnings.filterwarnings(action="ignore", category=UserWarning)
        library_code = PKG_ROOT / "langchain"
        for path in library_code.rglob("*.py"):
            # Calculate the relative path to the module
            module_name = (
                path.relative_to(PKG_ROOT).with_suffix("").as_posix().replace("/", ".")
            )
            if module_name.endswith("__init__"):
                # Without init
                module_name = module_name.rsplit(".", 1)[0]

            mod = importlib.import_module(module_name)

            all_attrs = getattr(mod, "__all__", [])

            for name in all_attrs:
                # Attempt to import the name from the module
                try:
                    obj = getattr(mod, name)
                    assert obj is not None
                except Exception as e:
                    msg = f"Could not import {module_name}.{name}"
                    raise AssertionError(msg) from e


def test_import_all_using_dir() -> None:
    """Test that every public attribute reported by dir() can be resolved."""
    library_code = PKG_ROOT / "langchain"
    for path in library_code.rglob("*.py"):
        # Calculate the relative path to the module
        module_name = (
            path.relative_to(PKG_ROOT).with_suffix("").as_posix().replace("/", ".")
        )
        if module_name.endswith("__init__"):
            # Without init
            module_name = module_name.rsplit(".", 1)[0]

        try:
            mod = importlib.import_module(module_name)
        except ModuleNotFoundError as e:
            msg = f"Could not import {module_name}"
            raise ModuleNotFoundError(msg) from e
        attributes = dir(mod)

        for name in attributes:
            if name.strip().startswith("_"):
                continue
            # Attempt to import the name from the module
            getattr(mod, name)
11
libs/langchain_v1/tests/unit_tests/test_pytest_config.py
Normal file
11
libs/langchain_v1/tests/unit_tests/test_pytest_config.py
Normal file
@@ -0,0 +1,11 @@
import pytest
import pytest_socket
import requests


def test_socket_disabled() -> None:
    """Network access should be blocked, so the outgoing request must fail."""
    with pytest.raises(pytest_socket.SocketBlockedError):
        # Ignore S113 since we don't need a timeout here as the request
        # should fail immediately
        requests.get("https://www.example.com")  # noqa: S113
14
libs/langchain_v1/tests/unit_tests/tools/test_imports.py
Normal file
14
libs/langchain_v1/tests/unit_tests/tools/test_imports.py
Normal file
@@ -0,0 +1,14 @@
from langchain import tools

EXPECTED_ALL = {
    "BaseTool",
    "InjectedToolArg",
    "InjectedToolCallId",
    "Tool",
    "ToolException",
    "tool",
}


def test_all_imports() -> None:
    assert set(tools.__all__) == EXPECTED_ALL
4562
libs/langchain_v1/uv.lock
generated
Normal file
4562
libs/langchain_v1/uv.lock
generated
Normal file
File diff suppressed because it is too large